Open In Colab

1. Two moons with an Invertible Neural Network¶

In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import pairwise_kernels
In [ ]:
### modified coupling layer class
import torch
import torch.nn as nn
import torch.nn.functional as F

class CouplingLayer(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(CouplingLayer, self).__init__()
        # Neural networks for the first half of the dimensions
        self.fc1 = nn.Linear(input_size // 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        # Translation coefficient
        self.fc3 = nn.Linear(hidden_size, input_size // 2)
        # Scaling coefficient
        self.fc4 = nn.Linear(hidden_size, input_size // 2)

    def forward(self, x):
        # Split the input into two halves
        x_a, x_b = x.chunk(2, dim=1)

        # Apply neural network to calculate coefficients
        h = F.relu(self.fc1(x_a))
        h = F.relu(self.fc2(h))
        translation = self.fc3(h)
        scaling_before_exp = torch.tanh(self.fc4(h))  # tanh bounds the log-scale for numerical stability
        scaling = torch.exp(scaling_before_exp)


        # Apply the affine transformation
        y_b = x_b * scaling + translation

        # Concatenate the transformed halves
        y = torch.cat([x_a, y_b], dim=1)
        return y, scaling_before_exp

    def backward(self, y):
        # Split the input into two halves
        y_a, y_b = y.chunk(2, dim=1)

        # Apply neural network to calculate coefficients (reverse)
        h = F.relu(self.fc1(y_a))
        h = F.relu(self.fc2(h))
        translation = self.fc3(h)
        scaling_before_exp = torch.tanh(self.fc4(h))  # same log-scale as in forward
        scaling = torch.exp(scaling_before_exp)

        # Reverse the operations to reconstruct the original input
        x_a = y_a
        x_b = (y_b - translation) / scaling

        # Concatenate the reconstructed halves
        x = torch.cat([x_a, x_b], dim=1)
        return x
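
A quick sanity check (a minimal sketch; the layer size and batch here are throwaway choices): a coupling layer must be exactly invertible, so backward(forward(x)) should reproduce x up to floating-point error.

In [ ]:
### invertibility check for a single coupling layer
layer = CouplingLayer(input_size=2, hidden_size=16)
x = torch.randn(4, 2)
y, _ = layer.forward(x)
print(torch.allclose(x, layer.backward(y), atol=1e-5))  # expected: True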
In [ ]:
class RealNVP(nn.Module):
    def __init__(self, input_size, hidden_size, blocks):
        super(RealNVP, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.blocks = blocks

        # List of coupling layers
        self.coupling_layers = nn.ModuleList([
            CouplingLayer(input_size, hidden_size) for _ in range(blocks)
        ])


        # Fixed random orthonormal matrices (not trained; they mix dimensions between blocks)
        self.orthonormal_matrices = [self._get_orthonormal_matrix(input_size) for _ in range(blocks)]

        # List to store scaling_before_exp for each block
        self.scaling_before_exp_list = []

    def _get_orthonormal_matrix(self, size):
        # Function to generate a random orthonormal matrix
        w = torch.randn(size, size)
        q, _ = torch.linalg.qr(w,'reduced')
        return q

    def forward_realnvp(self, x):
        scaling_before_exp_list = []
        for i in range(self.blocks):

            # Apply random orthonormal matrix
            x = torch.matmul(x, self.orthonormal_matrices[i])

            # Apply coupling layer
            x, scaling_before_exp = self.coupling_layers[i].forward(x)
            scaling_before_exp_list.append(scaling_before_exp)

        self.scaling_before_exp_list = scaling_before_exp_list
        return x

    def encode(self, x):
        # Encoding is the forward pass through the RealNVP model
        return self.forward_realnvp(x)

    def decode(self, z):
        # Reverse transformations for decoding
        for i in reversed(range(self.blocks)):

            # Apply coupling layer (reverse)
            z = self.coupling_layers[i].backward(z)

            # Apply random orthonormal matrix (reverse)
            z = torch.matmul(z, self.orthonormal_matrices[i].t())
        return z

    def sample(self, num_samples=1000):
        # Generate random samples from a standard normal distribution
        with torch.no_grad():
            z = torch.randn(num_samples, self.input_size)

        # Apply the reverse transformations (decoder) to generate synthetic samples
        synthetic_samples = self.decode(z)
        return synthetic_samples
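
The same round-trip check for the full model (again a sketch with throwaway sizes): encode followed by decode should be the identity, since each block is an orthonormal rotation followed by an invertible coupling layer.

In [ ]:
### invertibility check for the full RealNVP model
test_model = RealNVP(input_size=2, hidden_size=16, blocks=3)
x = torch.randn(5, 2)
with torch.no_grad():
    x_rec = test_model.decode(test_model.encode(x))
print(torch.allclose(x, x_rec, atol=1e-4))  # expected: True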
In [ ]:
### defining our loss function
def calculate_loss(transformed_x, scaling_before_exp_list, dataset_length):
    """
    Calculate the loss for the RealNVP model.

    Args:
    - transformed_x (tensor): Transformed data produced by the RealNVP model.
    - scaling_before_exp_list (list): List of scaling_before_exp values for each block.
    - dataset_length (int): Normalization constant for the loss (the training loop below passes the number of batches, len(train_loader)).

    Returns:
    - loss (tensor): The calculated loss value.
    """

    # First term: squared norm of the codes under a standard normal prior
    first_term = 0.5 * torch.sum(transformed_x**2)

    # Second term: negative log-determinant of the Jacobian (sum of log-scales over all blocks)
    second_term = -torch.sum(torch.cat(scaling_before_exp_list))

    # Calculate the total loss
    loss = (first_term + second_term) / dataset_length

    return loss
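
For reference, the quantity we minimise is the change-of-variables negative log-likelihood under a standard normal prior on the codes $z = f(x)$, up to an additive constant:

$$-\log p(x) \;=\; \tfrac{1}{2}\lVert f(x)\rVert^2 \;-\; \log\Big|\det \frac{\partial f}{\partial x}\Big|, \qquad \log\Big|\det \frac{\partial f}{\partial x}\Big| \;=\; \sum_{\text{blocks}} \sum_j s_j,$$

where the $s_j$ are the scaling_before_exp values (the log-scales of each coupling layer) and the orthonormal rotations contribute $\log|\det Q| = 0$. This matches first_term and second_term above.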
In [ ]:
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

def train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=0.001, print_after=1):
    """
    Train the RealNVP model and evaluate on a validation dataset.

    Args:
    - model (RealNVP): The RealNVP model to be trained.
    - train_loader (DataLoader): DataLoader for the training dataset.
    - val_loader (DataLoader): DataLoader for the validation dataset.
    - num_epochs (int): Number of training epochs.
    - lr (float): Learning rate for the optimizer.
    - print_after (int): Number of epochs after which to print the training and validation loss.

    Returns:
    - train_losses (list): List of training losses for each epoch.
    - val_losses (list): List of validation losses for each epoch.
    """

    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []  # List to store training losses
    val_losses = []    # List to store validation losses

    for epoch in range(num_epochs):
        total_train_loss = 0.0

        # Training phase
        model.train()  # Set the model to training mode
        for data in train_loader:
            inputs = data

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass (encoding)
            encoded = model.encode(inputs)

            # Loss calculation
            train_loss = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))

            # Backward pass (gradient computation)
            train_loss.backward()

            # Clip the gradients to keep training stable
            clip_grad_norm_(model.parameters(), max_norm=1.0)  # Adjust max_norm as needed

            # Update weights
            optimizer.step()

            total_train_loss += train_loss.item()

        # Average training loss for the epoch
        average_train_loss = total_train_loss / len(train_loader)

        # Validation phase
        if val_loader is not None:
            model.eval()  # Set the model to evaluation mode
            total_val_loss = 0.0
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs = val_data

                    # Forward pass (encoding) for validation
                    val_encoded = model.encode(val_inputs)

                    # Loss calculation for validation
                    val_loss = calculate_loss(val_encoded, model.scaling_before_exp_list, len(val_loader))

                    total_val_loss += val_loss.item()

            # Average validation loss for the epoch
            average_val_loss = total_val_loss / len(val_loader)

            # Print training and validation losses together
            if (epoch + 1) % print_after == 0:
                print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss}, Validation Loss: {average_val_loss}")

            # Append losses to the lists
            train_losses.append(average_train_loss)
            val_losses.append(average_val_loss)

        # Set the model back to training mode
        model.train()

    print("Training complete")

    return train_losses, val_losses

Some helper functions for plotting¶

In [ ]:
# function to plot training and validation losses
def plot_losses(epoch_train_losses, epoch_val_losses, want_log_scale=True):
    """
    Plot training and validation losses over epochs (optionally on a log scale).

    Args:
        epoch_train_losses (list): List of training losses for each epoch.
        epoch_val_losses (list): List of validation losses for each epoch.
    """
    epochs = range(1, len(epoch_train_losses) + 1)

    plt.plot(epochs, epoch_train_losses, label='Training Loss')
    plt.plot(epochs, epoch_val_losses, label='Validation Loss')

    if want_log_scale:
      plt.yscale('log')  # Set the y-axis to a logarithmic scale
      plt.title('Training and Validation Losses (Log Scale)', fontsize=10)
    else:
      plt.title('Training and Validation Losses', fontsize=10)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
In [ ]:
def visualize_synthetic_data(original_data, synthetic_data):
    """
    Scatter plot to visualize the original and synthetic data in 2D.

    Args:
    - original_data (torch.Tensor): Original data (2D).
    - synthetic_data (torch.Tensor): Synthetic data (2D).

    Returns:
    - None: Displays the scatter plot.
    """
    # Ensure both original and synthetic data are converted to numpy arrays
    with torch.no_grad():
        # Convert PyTorch tensors to numpy arrays
        original_np = original_data.numpy()
        synthetic_np = synthetic_data.numpy()

        # Scatter plot of original and synthetic data
        plt.scatter(original_np[:, 0], original_np[:, 1], label='Original', alpha=0.5)
        plt.scatter(synthetic_np[:, 0], synthetic_np[:, 1], label='Synthetic', alpha=0.5)

        # Add labels and title
        plt.xlabel("dimension-1")
        plt.ylabel("dimension-2")
        plt.title('Original vs Synthetic Data')

        # Add legend
        plt.legend()

        # Display the plot
        #plt.show()
In [ ]:
def plot_code_distribution(model, test_loader, num_samples=1000):
    """
    Plot the code distribution obtained by applying the trained RealNVP model to a test dataset.

    Args:
    - model (RealNVP): Trained RealNVP model.
    - test_loader (DataLoader): DataLoader for the test dataset.
    - num_samples (int): Number of samples to visualize.

    Returns:
    None (displays the plot).
    """
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        # Concatenate multiple batches to obtain more samples
        test_samples = torch.cat([batch for batch in test_loader], dim=0)

        # Assuming your model has an `encode` method
        code_samples = model.encode(test_samples[:num_samples])

        # Convert PyTorch tensor to numpy array
        code_np = code_samples.numpy()

        # Scatter plot of code distribution
        plt.scatter(code_np[:, 0], code_np[:, 1], label='Code Distribution', alpha=0.5)
        plt.xlabel("Code Dimension 1")
        plt.ylabel("Code Dimension 2")
        plt.title('Code Distribution')
        plt.legend()
        #plt.show()

loading the two moons dataset¶

In [ ]:
dataset_sizes = [ 100, 200,300,400,500,600,700,800,900, 1000, 5000]

# Generate datasets of varying sizes
train_datasets = {}
val_datasets = {}
datasets = {}

for size in dataset_sizes:
    X, y = make_moons(n_samples=size, noise=0.1)
    datasets[size] = {'X': X, 'y': y}
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    train_datasets[size] = {'X': torch.FloatTensor(X_train), 'y': y_train}
    val_datasets[size] = {'X': torch.FloatTensor(X_test), 'y': y_test}

# # Visualize the training datasets
# plt.figure(figsize=(12, 8))

# for i, size in enumerate(dataset_sizes, 1):
#     plt.subplot(2, 2, i)
#     plt.scatter(datasets[size]['X'][:, 0], datasets[size]['X'][:, 1], c=datasets[size]['y'])
#     plt.title(f'Dataset Size: {size}')

# plt.show()
In [ ]:
### creating the dataloader for the make moons dataset
from torch.utils.data import DataLoader, TensorDataset

### Trial run
import numpy as np
input_size=2
hidden_size=200  ### do I really need this to be this large?
blocks=10  ### a larger number of blocks pushes the code distribution closer to Gaussian
print_after=1

#### data for the two-moons model
dataset_size=5000
batch_size=32
data_considered=train_datasets[dataset_size]['X']
print("shape of the data_considered"); print(data_considered.shape)
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=batch_size, shuffle=True)
val_loader= torch.utils.data.DataLoader(val_datasets[dataset_size]['X'], batch_size=batch_size, shuffle=True)
####

### instantiate the model
model= RealNVP(input_size=2, hidden_size= hidden_size, blocks=blocks)

## train the model
train_losses, val_losses= train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=0.0001, print_after=1)
### earlier trial note: lr=0.00005, num_epochs=20, dataset_size=5000, batch_size=64, blocks=10
### gave a more Gaussian code distribution and comparatively better generated data

# plotting the loss
plot_losses(train_losses[3:], val_losses[3:], want_log_scale=0)
plt.show()

# Example usage:
plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
plt.show()

### plot the synthetic data and the original data
synthetic_data=model.sample(num_samples=1000)
visualize_synthetic_data(train_datasets[1000]['X'], synthetic_data)
plt.show()
shape of the data_considered
torch.Size([3500, 2])
Epoch 1/10, Training Loss: -0.069945588813756, Validation Loss: -0.2947036237158674
Epoch 2/10, Training Loss: -0.13655666284559465, Validation Loss: -0.3800994145109298
Epoch 3/10, Training Loss: -0.15644234075126323, Validation Loss: -0.4005805108141392
Epoch 4/10, Training Loss: -0.15004104049876332, Validation Loss: -0.3940854034525283
Epoch 5/10, Training Loss: -0.17315212689678777, Validation Loss: -0.42223459038328615
Epoch 6/10, Training Loss: -0.17323222608220848, Validation Loss: -0.34632147388889434
Epoch 7/10, Training Loss: -0.17355674239383503, Validation Loss: -0.4644995800992276
Epoch 8/10, Training Loss: -0.18419327680021524, Validation Loss: -0.3325619826767039
Epoch 9/10, Training Loss: -0.18445073308592494, Validation Loss: -0.45367521745093325
Epoch 10/10, Training Loss: -0.18611254140057348, Validation Loss: -0.40497843128569583
Training complete

1.1 Effect of number of coupling blocks in the network:¶

In [ ]:
import numpy as np
from sklearn.metrics.pairwise import pairwise_kernels
import matplotlib.pyplot as plt

def compute_mmd(X, Y, kernel='rbf', gamma=None):
    """
    Compute Maximum Mean Discrepancy (MMD) between two datasets.

    Parameters:
    - X, Y: Input datasets (numpy arrays).
    - kernel: Kernel function to use ('linear', 'rbf', etc.).
    - gamma: Kernel coefficient for 'rbf' kernel (if applicable).

    Returns:
    - mmd: Maximum Mean Discrepancy value (biased estimate of the squared MMD).
    """

    X = X.detach().numpy() if isinstance(X, torch.Tensor) else X
    Y = Y.detach().numpy() if isinstance(Y, torch.Tensor) else Y

    # Compute pairwise kernel matrices
    K_xx = pairwise_kernels(X, X, metric=kernel, gamma=gamma)
    K_yy = pairwise_kernels(Y, Y, metric=kernel, gamma=gamma)
    K_xy = pairwise_kernels(X, Y, metric=kernel, gamma=gamma)

    # Compute MMD
    mmd = np.mean(K_xx) + np.mean(K_yy) - 2 * np.mean(K_xy)
    return mmd
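
A tiny usage check (a sketch with made-up Gaussian samples): two samples from the same distribution should give an MMD near zero, while a shifted sample should give a clearly larger value.

In [ ]:
rng = np.random.default_rng(0)
same = compute_mmd(rng.normal(size=(500, 2)), rng.normal(size=(500, 2)))
shifted = compute_mmd(rng.normal(size=(500, 2)), rng.normal(loc=3.0, size=(500, 2)))
print(f"same distribution: {same:.4f}, shifted distribution: {shifted:.4f}")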
In [ ]:
### Input_size=2, hidden_size=200, lr=0.0001, num_epochs=10: Fixed
def train_and_plot_for_different_block_sizes(blocks_values, train_loader, val_loader):
    results = []

    for blocks in blocks_values:
        print(f"\nTraining for blocks={blocks}")
        # Instantiate the model
        model = RealNVP(input_size=2, hidden_size=200, blocks=blocks)

        # Train the model
        train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=0.0001, print_after=100)

        # Plot code distribution
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.xlim(-3,3)
        plt.ylim(-3,3)
        plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
        plt.title(f'Code Distribution (blocks={blocks})')

        # Plot synthetic data
        plt.subplot(1, 2, 2)
        synthetic_data = model.sample(num_samples=1000)
        visualize_synthetic_data(train_datasets[1000]['X'], synthetic_data)
        plt.title(f'Synthetic Data (blocks={blocks})')

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

        # # Calculate MMD score
        mmd_value = compute_mmd(val_loader.dataset, synthetic_data)
        print(f'MMD Score (blocks={blocks}): {mmd_value:.4f}')

        results.append((blocks, mmd_value))

    # # Plot MMD scores
    plt.figure(figsize=(8, 5))
    blocks, mmd_values = zip(*results)
    plt.plot(blocks, mmd_values, marker='o')
    plt.title('MMD Scores for Different Number of Blocks')
    plt.yscale('log')
    plt.xlabel('Blocks')
    plt.ylabel('MMD Score')
    plt.show()


######### for different coupling blocks


dataset_size=5000
print(f"For fixed dataset_size={dataset_size},  hidden_size=200, lr=0.0001, num_epochs=10")
batch_size=32
data_considered=train_datasets[dataset_size]['X']
print("shape of the data_considered"); print(data_considered.shape)
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=batch_size, shuffle=True)
val_loader= torch.utils.data.DataLoader(val_datasets[dataset_size]['X'], batch_size=batch_size, shuffle=True)
####

blocks_values_to_try = [1,2,10,15]
train_and_plot_for_different_block_sizes(blocks_values_to_try, train_loader, val_loader)
For fixed dataset_size=5000,  hidden_size=200, lr=0.0001, num_epochs=10
shape of the data_considered
torch.Size([3500, 2])

Training for blocks=1
Training complete
MMD Score (blocks=1): 0.0859

Training for blocks=2
Training complete
MMD Score (blocks=2): 0.0092

Training for blocks=10
Training complete
MMD Score (blocks=10): 0.0040

Training for blocks=15
Training complete
MMD Score (blocks=15): 0.0008

Observation:¶

  1. Number of coupling blocks:

    1. We observed that a higher number of coupling blocks pushes the code distribution closer to a standard Gaussian, and the quality of the synthetic data improves accordingly (the MMD score drops from 0.0859 with 1 block to 0.0008 with 15).

1.2 Effect of training set size¶

In [ ]:
def train_and_plot_for_different_dataset_sizes(dataset_sizes, train_loader, val_loader):
    results = []

    for dataset_size in dataset_sizes:
        print(f"\nTraining for dataset_size={dataset_size}")

        # Instantiate the model
        model = RealNVP(input_size=2, hidden_size=200, blocks=10)  # Fix other parameters

        # Create data loader for the current dataset size
        data_considered = train_datasets[dataset_size]['X']
        train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_datasets[dataset_size]['X'], batch_size=32, shuffle=True)

        # Train the model
        train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=0.0001, print_after=100)

        # Plot code distribution
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.xlim(-3, 3)
        plt.ylim(-3, 3)
        plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
        plt.title(f'Code Distribution (dataset_size={dataset_size})')

        # Plot synthetic data
        plt.subplot(1, 2, 2)
        synthetic_data = model.sample(num_samples=1000)
        visualize_synthetic_data(train_datasets[dataset_size]['X'], synthetic_data)
        plt.title(f'Synthetic Data (dataset_size={dataset_size})')

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

        # Calculate MMD score
        mmd_value = compute_mmd(val_loader.dataset, synthetic_data)
        print(f'MMD Score (dataset_size={dataset_size}): {mmd_value:.4f}')

        results.append((dataset_size, mmd_value))

    # Plot MMD scores
    plt.figure(figsize=(8, 5))
    dataset_sizes, mmd_values = zip(*results)
    plt.plot(dataset_sizes, mmd_values, marker='o')
    plt.title('MMD Scores for Different Dataset Sizes')
    plt.yscale('log')
    plt.xlabel('Dataset Size')
    plt.ylabel('MMD Score')
    plt.show()


# Different dataset sizes to try
print(f"For fixed number of blocks=10,  hidden_size=200, lr=0.0001, num_epochs=10")
dataset_sizes_to_try = [ 100, 200,300,400,500,600,700,800,900, 1000, 5000]
train_and_plot_for_different_dataset_sizes(dataset_sizes_to_try, train_loader, val_loader)
For fixed number of blocks=10,  hidden_size=200, lr=0.0001, num_epochs=10

Training for dataset_size=100
Training complete
MMD Score (dataset_size=100): 0.0202

Training for dataset_size=200
Training complete
MMD Score (dataset_size=200): 0.0222

Training for dataset_size=300
Training complete
MMD Score (dataset_size=300): 0.0111

Training for dataset_size=400
Training complete
MMD Score (dataset_size=400): 0.0030

Training for dataset_size=500
Training complete
MMD Score (dataset_size=500): 0.0083

Training for dataset_size=600
Training complete
MMD Score (dataset_size=600): 0.0037

Training for dataset_size=700
Training complete
MMD Score (dataset_size=700): 0.0042

Training for dataset_size=800
Training complete
MMD Score (dataset_size=800): 0.0046

Training for dataset_size=900
Training complete
MMD Score (dataset_size=900): 0.0033

Training for dataset_size=1000
Training complete
MMD Score (dataset_size=1000): 0.0052

Training for dataset_size=5000
Training complete
MMD Score (dataset_size=5000): 0.0015

Observation for different dataset sizes:¶

We did a few trial runs (of the cell above) for datasets of different sizes. We observed that, in general, the quality of the generated synthetic data improves as the size of the training set increases.

1.3. Effect of Learning rate¶

In [ ]:
def train_and_plot_for_different_learning_rates(learning_rates, dataset_size=1000, block_size=10):
    results = []

    for lr in learning_rates:
        print(f"\nTraining for learning rate={lr}")

        # Instantiate the model
        model = RealNVP(input_size=2, hidden_size=200, blocks=block_size)  # Fix other parameters

        # Create data loader for the fixed dataset size
        data_considered = train_datasets[dataset_size]['X']
        train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_datasets[dataset_size]['X'], batch_size=32, shuffle=True)

        # Train the model
        train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=lr, print_after=1)

        # Plot code distribution
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.xlim(-3, 3)
        plt.ylim(-3, 3)
        plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
        plt.title(f'Code Distribution (learning_rate={lr})')

        # Plot synthetic data
        plt.subplot(1, 2, 2)
        synthetic_data = model.sample(num_samples=1000)
        visualize_synthetic_data(train_datasets[dataset_size]['X'], synthetic_data)
        plt.title(f'Synthetic Data (learning_rate={lr})')

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

        # Calculate MMD score
        mmd_value = compute_mmd(val_loader.dataset, synthetic_data)
        print(f'MMD Score (learning_rate={lr}): {mmd_value:.4f}')

        results.append((lr, mmd_value))

    # Plot MMD scores
    plt.figure(figsize=(8, 5))
    learning_rates, mmd_values = zip(*results)
    plt.plot(learning_rates, mmd_values, marker='o')
    plt.title('MMD Scores for Different Learning Rates')
    plt.xlabel('Learning Rate')
    plt.ylabel('MMD Score')
    plt.xscale('log')  # Use a logarithmic scale for better visualization of different orders of magnitude
    plt.yscale('log')
    plt.show()


# Different learning rates to try
print("For fixed number of blocks=10, hidden_size=200, dataset_size=1000, num_epochs=10")
learning_rates_to_try = [0.01,0.005,0.0005,0.0001,0.000005]
train_and_plot_for_different_learning_rates(learning_rates_to_try, dataset_size=1000, block_size=10)
For fixed number of blocks=10, hidden_size=200, dataset_size=1000, num_epochs=10

Training for learning rate=0.01
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Exercise_3_GNN_for_science.ipynb Cell 23 line 56
     54 print("For fixed number of blocks=10, hidden_size=200, dataset_size=1000, num_epochs=10")
     55 learning_rates_to_try = [0.01,0.005,0.0005,0.0001,0.000005]
---> 56 train_and_plot_for_different_learning_rates(learning_rates_to_try, dataset_size=1000, block_size=10)

Exercise_3_GNN_for_science.ipynb Cell 23 line 11
      8 model = RealNVP(input_size=2, hidden_size=200, blocks=block_size)  # Fix other parameters
     10 # Create data loader for the fixed dataset size
---> 11 data_considered = train_datasets[dataset_size]['X']

KeyError: 1000

Observation:¶

We observed that lower learning rates suit the RealNVP (INN) model best; a learning rate of 1e-4 was the most suitable in this case. (The traceback above most likely stems from re-running this cell after train_datasets was redefined by the digits section below, whose keys are percentages rather than sizes, hence the KeyError: 1000; the observation refers to the earlier successful runs.)

1.4. Effect of number of epochs¶

In [ ]:
def train_and_plot_for_different_epochs(epochs_list, dataset_size=1000, block_size=10, lr=0.0001):
    results = []

    for num_epochs in epochs_list:
        print(f"\nTraining for num_epochs={num_epochs}")

        # Instantiate the model
        model = RealNVP(input_size=2, hidden_size=200, blocks=block_size)  # Fix other parameters

        # Create data loader for the fixed dataset size
        data_considered = train_datasets[dataset_size]['X']
        train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_datasets[dataset_size]['X'], batch_size=32, shuffle=True)

        # Train the model
        train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=num_epochs, lr=lr, print_after=2)

        # Plot code distribution
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.xlim(-3, 3)
        plt.ylim(-3, 3)
        plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
        plt.title(f'Code Distribution (num_epochs={num_epochs})')

        # Plot synthetic data
        plt.subplot(1, 2, 2)
        synthetic_data = model.sample(num_samples=1000)
        visualize_synthetic_data(train_datasets[dataset_size]['X'], synthetic_data)
        plt.title(f'Synthetic Data (num_epochs={num_epochs})')

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

        # Calculate MMD score
        mmd_value = compute_mmd(val_loader.dataset, synthetic_data)
        print(f'MMD Score (num_epochs={num_epochs}): {mmd_value:.4f}')

        results.append((num_epochs, mmd_value))

    # Plot MMD scores
    plt.figure(figsize=(8, 5))
    num_epochs_values, mmd_values = zip(*results)
    plt.plot(num_epochs_values, mmd_values, marker='o')
    plt.title('MMD Scores for Different Numbers of Epochs')
    plt.xlabel('Number of Epochs')
    plt.ylabel('MMD Score')
    plt.show()


# Different numbers of epochs to try
print("For fixed number of blocks=10, hidden_size=200, dataset_size=5000, lr=0.0001")
epochs_to_try = [10,40,70,100]
train_and_plot_for_different_epochs(epochs_to_try, dataset_size=5000, block_size=10, lr=0.0001)
For fixed number of blocks=10, hidden_size=200, dataset_size=5000, lr=0.0001

Training for num_epochs=10
Epoch 2/10, Training Loss: -0.14358886231414297, Validation Loss: -0.2803282976742497
Epoch 4/10, Training Loss: -0.17533571784469215, Validation Loss: -0.4116998964801748
Epoch 6/10, Training Loss: -0.18580710415474394, Validation Loss: -0.4493791166138142
Epoch 8/10, Training Loss: -0.1960323669609021, Validation Loss: -0.4617179879482756
Epoch 10/10, Training Loss: -0.20351748117669063, Validation Loss: -0.48170990528578456
Training complete
MMD Score (num_epochs=10): 0.0073

Training for num_epochs=40
Epoch 2/40, Training Loss: -0.1378864387388934, Validation Loss: -0.3690203459973031
Epoch 4/40, Training Loss: -0.1580163512632928, Validation Loss: -0.38746916185668173
Epoch 6/40, Training Loss: -0.17779654987495053, Validation Loss: -0.43654367295985524
Epoch 8/40, Training Loss: -0.1889630897986618, Validation Loss: -0.4552901840590416
Epoch 10/40, Training Loss: -0.19047909550030123, Validation Loss: -0.42791273777789257
Epoch 12/40, Training Loss: -0.19354358423839915, Validation Loss: -0.4706975151883795
Epoch 14/40, Training Loss: -0.2010810967704112, Validation Loss: -0.46784974095669196
Epoch 16/40, Training Loss: -0.2049939930523661, Validation Loss: -0.44716153975496903
Epoch 18/40, Training Loss: -0.20869793099435893, Validation Loss: -0.49485309897585117
Epoch 20/40, Training Loss: -0.20657219410958616, Validation Loss: -0.4554901722263783
Epoch 22/40, Training Loss: -0.20321482946588235, Validation Loss: -0.46633801751948417
Epoch 24/40, Training Loss: -0.21119618913666768, Validation Loss: -0.4876342476048368
Epoch 26/40, Training Loss: -0.20632553188638253, Validation Loss: -0.46331152272351245
Epoch 28/40, Training Loss: -0.21113421246409417, Validation Loss: -0.4772615965376509
Epoch 30/40, Training Loss: -0.21129591796885838, Validation Loss: -0.4784514904022217
Epoch 32/40, Training Loss: -0.21280228827487338, Validation Loss: -0.4888522742276496
Epoch 34/40, Training Loss: -0.2113928150724281, Validation Loss: -0.4908938553739101
Epoch 36/40, Training Loss: -0.21354313719679008, Validation Loss: -0.5119861174137035
Epoch 38/40, Training Loss: -0.21345627887005156, Validation Loss: -0.45827306902154963
Epoch 40/40, Training Loss: -0.21106771359389478, Validation Loss: -0.4497500956058502
Training complete
MMD Score (num_epochs=40): 0.0046

Training for num_epochs=70
Epoch 2/70, Training Loss: -0.13076824862668715, Validation Loss: -0.37503495146619514
Epoch 4/70, Training Loss: -0.1627920740707354, Validation Loss: -0.40823768808486616
Epoch 6/70, Training Loss: -0.1742848001081835, Validation Loss: -0.4192498841501297
Epoch 8/70, Training Loss: -0.19004247720268638, Validation Loss: -0.4279984803275859
Epoch 10/70, Training Loss: -0.19786307911642573, Validation Loss: -0.46302134083940627
Epoch 12/70, Training Loss: -0.20096099491823805, Validation Loss: -0.42286763165859464
Epoch 14/70, Training Loss: -0.20947395664724436, Validation Loss: -0.3848408782418738
Epoch 16/70, Training Loss: -0.21476675871420992, Validation Loss: -0.47927849850756055
Epoch 18/70, Training Loss: -0.21008852278305726, Validation Loss: -0.5150847650588827
Epoch 20/70, Training Loss: -0.21329289078712463, Validation Loss: -0.5035497875923806
Epoch 22/70, Training Loss: -0.21433152136477557, Validation Loss: -0.4633709376162671
Epoch 24/70, Training Loss: -0.21784016591581432, Validation Loss: -0.4931467118415427
Epoch 26/70, Training Loss: -0.2165901618925008, Validation Loss: -0.5155834771217184
Epoch 28/70, Training Loss: -0.21100665737282145, Validation Loss: -0.51318791952539
Epoch 30/70, Training Loss: -0.21846945570273832, Validation Loss: -0.4833254655624958
Epoch 32/70, Training Loss: -0.21690888702869415, Validation Loss: -0.5067412986400279
Epoch 34/70, Training Loss: -0.22221683649854226, Validation Loss: -0.49461661437724497
Epoch 36/70, Training Loss: -0.21972130513326688, Validation Loss: -0.5188743472099304
Epoch 38/70, Training Loss: -0.2154702734117481, Validation Loss: -0.518978476524353
Epoch 40/70, Training Loss: -0.21986391347917644, Validation Loss: -0.5252069203143425
Epoch 42/70, Training Loss: -0.22182731194929642, Validation Loss: -0.5229760636674597
Epoch 44/70, Training Loss: -0.2206318766556003, Validation Loss: -0.5316517213557629
Epoch 46/70, Training Loss: -0.2191770705648444, Validation Loss: -0.5050297106834168
Epoch 48/70, Training Loss: -0.22091747705232012, Validation Loss: -0.5135451238206092
Epoch 50/70, Training Loss: -0.22272868650880726, Validation Loss: -0.507524533474699
Epoch 52/70, Training Loss: -0.22765119698914615, Validation Loss: -0.5047896504402161
Epoch 54/70, Training Loss: -0.21782343824478714, Validation Loss: -0.5143627763745633
Epoch 56/70, Training Loss: -0.22382738732478835, Validation Loss: -0.5082689720265409
Epoch 58/70, Training Loss: -0.2197471954436465, Validation Loss: -0.5090845556969338
Epoch 60/70, Training Loss: -0.2229800802401521, Validation Loss: -0.5346206350529448
Epoch 62/70, Training Loss: -0.22533906894651326, Validation Loss: -0.5416639390143942
Epoch 64/70, Training Loss: -0.2256680274890228, Validation Loss: -0.5050788458357466
Epoch 66/70, Training Loss: -0.22534454664723438, Validation Loss: -0.5375226907273556
Epoch 68/70, Training Loss: -0.2253061113709753, Validation Loss: -0.5076465005491008
Epoch 70/70, Training Loss: -0.22640712132508103, Validation Loss: -0.5100123501838522
Training complete
MMD Score (num_epochs=70): 0.0015

Training for num_epochs=100
Epoch 2/100, Training Loss: -0.13594088523902675, Validation Loss: -0.3640339044814414
Epoch 4/100, Training Loss: -0.1487430640072985, Validation Loss: -0.39018139021193726
Epoch 6/100, Training Loss: -0.1542976822873408, Validation Loss: -0.42238458230140363
Epoch 8/100, Training Loss: -0.16581505732187493, Validation Loss: -0.38754124382629673
Epoch 10/100, Training Loss: -0.17620162418831817, Validation Loss: -0.41299855043279365
Epoch 12/100, Training Loss: -0.1797038776524873, Validation Loss: -0.46067460800739046
Epoch 14/100, Training Loss: -0.1875358109193092, Validation Loss: -0.4512183618672351
Epoch 16/100, Training Loss: -0.19431041539223357, Validation Loss: -0.46909774491127504
Epoch 18/100, Training Loss: -0.19388091472739524, Validation Loss: -0.47596237412158476
Epoch 20/100, Training Loss: -0.19564713624475355, Validation Loss: -0.4955769181251526
Epoch 22/100, Training Loss: -0.19411445575004274, Validation Loss: -0.4284949873356109
Epoch 24/100, Training Loss: -0.20146982812068678, Validation Loss: -0.4502627342305285
Epoch 26/100, Training Loss: -0.2027834393422712, Validation Loss: -0.4436948429396812
Epoch 28/100, Training Loss: -0.20560864592817696, Validation Loss: -0.49413373876125255
Epoch 30/100, Training Loss: -0.20735824114897033, Validation Loss: -0.43310117055761055
Epoch 32/100, Training Loss: -0.20372144708579237, Validation Loss: -0.49495308323109405
Epoch 34/100, Training Loss: -0.2052411500364542, Validation Loss: -0.3276872003569882
Epoch 36/100, Training Loss: -0.20970047583634202, Validation Loss: -0.449070229175243
Epoch 38/100, Training Loss: -0.21172828359254212, Validation Loss: -0.4277269909990595
Epoch 40/100, Training Loss: -0.21106311015107415, Validation Loss: -0.45086967691462093
Epoch 42/100, Training Loss: -0.21024575511162932, Validation Loss: -0.4654041325792353
Epoch 44/100, Training Loss: -0.21036843491548843, Validation Loss: -0.47354090942981397
Epoch 46/100, Training Loss: -0.2124092212793502, Validation Loss: -0.5061322875479435
Epoch 48/100, Training Loss: -0.21252718266438353, Validation Loss: -0.48068189684380874
Epoch 50/100, Training Loss: -0.21141986224631018, Validation Loss: -0.49343108750404197
Epoch 52/100, Training Loss: -0.21551176594062285, Validation Loss: -0.49017775883065895
Epoch 54/100, Training Loss: -0.21384021821008486, Validation Loss: -0.45886730925833924
Epoch 56/100, Training Loss: -0.2167235734110529, Validation Loss: -0.49913877121945643
Epoch 58/100, Training Loss: -0.21804992294108325, Validation Loss: -0.4848848796905355
Epoch 60/100, Training Loss: -0.21321812539615415, Validation Loss: -0.4758980394677913
Epoch 62/100, Training Loss: -0.21669397191567855, Validation Loss: -0.4722903497675632
Epoch 64/100, Training Loss: -0.2187076737934893, Validation Loss: -0.4830910129115937
Epoch 66/100, Training Loss: -0.21772300817749718, Validation Loss: -0.5008368403353589
Epoch 68/100, Training Loss: -0.21606271212751216, Validation Loss: -0.5120836623171543
Epoch 70/100, Training Loss: -0.2175406933508136, Validation Loss: -0.48003573049890236
Epoch 72/100, Training Loss: -0.21688812923702327, Validation Loss: -0.5070201881388401
Epoch 74/100, Training Loss: -0.21162096383896742, Validation Loss: -0.5203415630979741
Epoch 76/100, Training Loss: -0.2169081981886517, Validation Loss: -0.483633760442125
Epoch 78/100, Training Loss: -0.22372531721537764, Validation Loss: -0.5064087939706254
Epoch 80/100, Training Loss: -0.21813252635977484, Validation Loss: -0.5143103402979831
Epoch 82/100, Training Loss: -0.2172615073621273, Validation Loss: -0.47639415841153326
Epoch 84/100, Training Loss: -0.21871006509119814, Validation Loss: -0.46699031727745177
Epoch 86/100, Training Loss: -0.2207058498678221, Validation Loss: -0.5056525043984677
Epoch 88/100, Training Loss: -0.22040676318786362, Validation Loss: -0.5213642380339034
Epoch 90/100, Training Loss: -0.21426767673004757, Validation Loss: -0.5214954779503194
Epoch 92/100, Training Loss: -0.21918462758714502, Validation Loss: -0.5133034390337924
Epoch 94/100, Training Loss: -0.22393759550018744, Validation Loss: -0.42689299164866246
Epoch 96/100, Training Loss: -0.22194279963997277, Validation Loss: -0.5066487827199571
Epoch 98/100, Training Loss: -0.2178570749407465, Validation Loss: -0.5179648266193715
Epoch 100/100, Training Loss: -0.21775149581530553, Validation Loss: -0.4457393101555236
Training complete
MMD Score (num_epochs=100): 0.0038

DONE (for the 2 moons dataset at least): ~Tasks for problem 1 of exercise 2 we need to do:~¶

Hyperparameters involved:

  1. Size of the training set
  2. Number of epochs
  3. Learning rate
  • Check the effect of these 3 hyperparameters on the model quality.
    1. Check for the quality of the code distribution (it should be indeed standard normal)
    2. Check the quality of the generated data: RealNVP should have a function RealNVP.sample(self, num_samples) that generates the requested number of synthetic points. REPORT the MMD between a test set and the generated data points: to be more specific, show that visually better results correspond to smaller MMD values.

Tasks still to do:¶

  1. Check whether the model is implemented correctly.
  2. MOST IMPORTANTLY: is the loss function implemented correctly? When we return abs(loss) in the calculate_loss function, we see some training and improvement, but we are not yet certain the implemented loss is right (see the Jacobian check below).
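
One way to test point 2 directly (a minimal sketch; tiny and its sizes are throwaway choices): in 2-D we can build the full Jacobian of encode() with autograd and compare its log-determinant against the analytic value the loss relies on, namely the sum of the tanh log-scales.

In [ ]:
torch.manual_seed(0)
tiny = RealNVP(input_size=2, hidden_size=8, blocks=2)
x0 = torch.randn(1, 2)

# analytic log|det J| accumulated by the forward pass
_ = tiny.encode(x0)
analytic = sum(s.sum().item() for s in tiny.scaling_before_exp_list)

# brute-force log|det J| from the full 2x2 Jacobian (slogdet, since the rotations may flip the sign)
J = torch.autograd.functional.jacobian(lambda v: tiny.encode(v.unsqueeze(0)).squeeze(0), x0.squeeze(0))
_, logabsdet = torch.slogdet(J)
print(f"analytic: {analytic:.6f}, autograd: {logabsdet.item():.6f}")  # should agree closely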

Question 2. Conditional INN¶

In [ ]:
### conditional coupling layer
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import one_hot

class ConditionalCouplingLayer(nn.Module):
    def __init__(self, input_size, hidden_size, condition_size):
        """
        Initialize a ConditionalCouplingLayer.

        Args:
        - input_size (int): Total size of the input data.
        - hidden_size (int): Size of the hidden layers in the neural networks.
        - condition_size (int): Size of the condition vector (e.g., one-hot encoded label size).
        """
        super(ConditionalCouplingLayer, self).__init__()
        # Neural networks for the first half of the dimensions
        self.fc1 = nn.Linear(input_size // 2 + condition_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        # Translation coefficient
        self.fc3 = nn.Linear(hidden_size, input_size // 2)
        # Scaling coefficient
        self.fc4 = nn.Linear(hidden_size, input_size // 2)

    def forward(self, x, condition):
        """
        Forward pass through the ConditionalCouplingLayer.

        Args:
        - x (torch.Tensor): Input data.
        - condition (torch.Tensor): Condition vector.

        Returns:
        - y (torch.Tensor): Transformed data.
        - scaling_before_exp (torch.Tensor): Scaling coefficients before the exponential operation.
        """
        # Split the input into two halves
        x_a, x_b = x.chunk(2, dim=1)

        # Concatenate conditions to the first half
        x_a_concat = torch.cat([x_a, condition], dim=1)

        # Apply neural network to calculate coefficients
        h = F.relu(self.fc1(x_a_concat))
        h = F.relu(self.fc2(h))
        translation = self.fc3(h)
        scaling_before_exp = torch.tanh(self.fc4(h))
        scaling = torch.exp(scaling_before_exp)

        # Apply the affine transformation
        y_b = x_b * scaling + translation

        # Concatenate the transformed halves
        y = torch.cat([x_a, y_b], dim=1)
        return y, scaling_before_exp

    def backward(self, y, condition):
        """
        Backward pass through the ConditionalCouplingLayer.

        Args:
        - y (torch.Tensor): Transformed data.
        - condition (torch.Tensor): Condition vector.

        Returns:
        - x (torch.Tensor): Reconstructed original input.
        """
        # Split the input into two halves
        y_a, y_b = y.chunk(2, dim=1)

        # Concatenate conditions to the first half
        y_a_concat = torch.cat([y_a, condition], dim=1)

        # Apply neural network to calculate coefficients (reverse)
        h = F.relu(self.fc1(y_a_concat))
        h = F.relu(self.fc2(h))
        translation = self.fc3(h)
        scaling_before_exp = torch.tanh(self.fc4(h))  # same log-scale as in forward
        scaling = torch.exp(scaling_before_exp)

        # Reverse the operations to reconstruct the original input
        x_a = y_a
        x_b = (y_b - translation) / scaling

        # Concatenate the reconstructed halves
        x = torch.cat([x_a, x_b], dim=1)
        return x
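
The same invertibility check as for the unconditional layer (a sketch; sizes and labels are throwaway): for a fixed condition, backward(forward(x, c), c) should reproduce x.

In [ ]:
layer = ConditionalCouplingLayer(input_size=2, hidden_size=16, condition_size=2)
x = torch.randn(4, 2)
cond = one_hot(torch.tensor([0, 1, 0, 1]), num_classes=2).float()
y, _ = layer.forward(x, cond)
print(torch.allclose(x, layer.backward(y, cond), atol=1e-5))  # expected: True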
In [ ]:
### conditional real NVP class
class ConditionalRealNVP(nn.Module):
    def __init__(self, input_size, hidden_size, condition_size, blocks):
        """
        Initialize a ConditionalRealNVP model.

        Args:
        - input_size (int): Total size of the input data.
        - hidden_size (int): Size of the hidden layers in the neural networks.
        - condition_size (int): Size of the condition vector (e.g., one-hot encoded label size).
        - blocks (int): Number of coupling layers in the model.
        """
        super(ConditionalRealNVP, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.condition_size = condition_size
        self.blocks = blocks

        # List of coupling layers
        self.coupling_layers = nn.ModuleList([
            ConditionalCouplingLayer(input_size, hidden_size, condition_size) for _ in range(blocks)
        ])

        # List to store orthonormal matrices
        self.orthonormal_matrices = [self._get_orthonormal_matrix(input_size) for _ in range(blocks)]

        # List to store scaling_before_exp for each block
        self.scaling_before_exp_list = []

    def _get_orthonormal_matrix(self, size):
        """
        Generate a random orthonormal matrix.

        Args:
        - size (int): Size of the matrix.

        Returns:
        - q (torch.Tensor): Orthonormal matrix.
        """
        w = torch.randn(size, size)
        q, _ = torch.linalg.qr(w, 'reduced')
        return q

    def forward_realnvp(self, x, condition):
        """
        Forward pass through the ConditionalRealNVP model.

        Args:
        - x (torch.Tensor): Input data.
        - condition (torch.Tensor): Condition vector.

        Returns:
        - x (torch.Tensor): Transformed data.
        """
        scaling_before_exp_list = []
        for i in range(self.blocks):
            #print("x is:"); print(x)
            #print("shape of x is:"); print(x.shape)
            x = torch.matmul(x, self.orthonormal_matrices[i])
            x, scaling_before_exp = self.coupling_layers[i].forward(x, condition)
            scaling_before_exp_list.append(scaling_before_exp)

        self.scaling_before_exp_list = scaling_before_exp_list
        return x

    def decode(self, z, condition):
        """
        Reverse transformations to decode data.

        Args:
        - z (torch.Tensor): Transformed data.
        - condition (torch.Tensor): Condition vector.

        Returns:
        - z (torch.Tensor): Reconstructed original data.
        """
        for i in reversed(range(self.blocks)):
            z = self.coupling_layers[i].backward(z, condition)
            z = torch.matmul(z, self.orthonormal_matrices[i].t())
        return z

    def sample(self, num_samples=1000, conditions=None):
        """
        Generate synthetic samples.

        Args:
        - num_samples (int): Number of synthetic samples to generate.
        - conditions (torch.Tensor): Conditions for generating synthetic samples.

        Returns:
        - synthetic_samples (torch.Tensor): Synthetic samples.
        """
        with torch.no_grad():
            z = torch.randn(num_samples, self.input_size)
            synthetic_samples = self.decode(z, conditions)
        return synthetic_samples
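
And the round trip for the full conditional model (a sketch with throwaway sizes): decoding the encoded data under the same condition should reproduce the input.

In [ ]:
cmodel = ConditionalRealNVP(input_size=2, hidden_size=16, condition_size=2, blocks=3)
x = torch.randn(4, 2)
cond = one_hot(torch.tensor([0, 1, 1, 0]), num_classes=2).float()
with torch.no_grad():
    x_rec = cmodel.decode(cmodel.forward_realnvp(x, cond), cond)
print(torch.allclose(x, x_rec, atol=1e-4))  # expected: True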
In [ ]:
### training_the_conditional_nvp model

import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

def train_and_validate_conditional_nvp(model, train_loader, val_loader, num_epochs=10, lr=0.001, print_after=1):
    """
    Train the ConditionalRealNVP model and evaluate on a validation dataset.

    Args:
    - model (ConditionalRealNVP): The ConditionalRealNVP model to be trained.
    - train_loader (DataLoader): DataLoader for the training dataset.
    - val_loader (DataLoader): DataLoader for the validation dataset.
    - num_epochs (int): Number of training epochs.
    - lr (float): Learning rate for the optimizer.
    - print_after (int): Number of epochs after which to print the training and validation loss.

    Returns:
    - train_losses (list): List of training losses for each epoch.
    - val_losses (list): List of validation losses for each epoch.
    """

    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []  # List to store training losses
    val_losses = []    # List to store validation losses

    for epoch in range(num_epochs):
        total_train_loss = 0.0

        # Training phase
        model.train()  # Set the model to training mode
        for data, labels in train_loader:
            inputs = data
            conditions = one_hot(labels, num_classes=model.condition_size).float()

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass (encoding)
            encoded = model.forward_realnvp(inputs, conditions)

            # Loss calculation
            train_loss = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))

            # Backward pass (gradient computation)
            train_loss.backward()

            # Clip gradients to prevent exploding gradients
            clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Update weights
            optimizer.step()

            total_train_loss += train_loss.item()

        # Average training loss for the epoch
        average_train_loss = total_train_loss / len(train_loader)

        # Validation phase
        if val_loader is not None:
            model.eval()  # Set the model to evaluation mode
            total_val_loss = 0.0
            with torch.no_grad():
                for val_data, val_labels in val_loader:
                    val_inputs = val_data
                    val_conditions = one_hot(val_labels, num_classes=model.condition_size).float()

                    # Forward pass (encoding) for validation
                    val_encoded = model.forward_realnvp(val_inputs, val_conditions)

                    # Loss calculation for validation
                    val_loss = calculate_loss(val_encoded, model.scaling_before_exp_list, len(val_loader))

                    total_val_loss += val_loss.item()

            # Average validation loss for the epoch
            average_val_loss = total_val_loss / len(val_loader)

            # Print training and validation losses together
            if (epoch + 1) % print_after == 0:
                print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss}, Validation Loss: {average_val_loss}")

            # Append losses to the lists
            train_losses.append(average_train_loss)
            val_losses.append(average_val_loss)

        # Set the model back to training mode
        model.train()

    print("Training complete")

    return train_losses, val_losses
In [ ]:
### Create the dataset and dataloaders for the conditional NVP model
dataset_sizes = [ 100, 200,300,400,500,600,700,800,900, 1000, 5000]

# Generate datasets of varying sizes
train_datasets = {}
val_datasets = {}
datasets = {}

for size in dataset_sizes:
    X, y = make_moons(n_samples=size, noise=0.1)
    datasets[size] = {'data': X, 'labels': y}  ### the label indicates which moon a point belongs to
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    train_datasets[size] = {'data': torch.FloatTensor(X_train), 'label': y_train}
    val_datasets[size] = {'data': torch.FloatTensor(X_test), 'label': y_test}
In [ ]:
#### data for the two-moons model
from torch.utils.data import TensorDataset, DataLoader

# Define a custom dataset
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.labels[index]
        return x, y
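
A quick check that the dataset yields the (features, label) pairs the conditional training loop expects (a sketch with made-up data):

In [ ]:
demo = CustomDataset(torch.randn(6, 2), torch.tensor([0, 1, 0, 1, 0, 1]))
xb, yb = next(iter(DataLoader(demo, batch_size=3)))
print(xb.shape, yb)  # torch.Size([3, 2]) tensor([0, 1, 0])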

task 1. train the conditional INN¶

In [ ]:
# Define model parameters
input_size = 2
hidden_size = 200
condition_size = 2
blocks = 10

# Initialize the model
conditional_inn_model = ConditionalRealNVP(input_size, hidden_size, condition_size, blocks)

# Define hyperparameters
num_epochs = 10
lr = 0.0001

# Create datasets
dataset_size=1000
train_dataset = CustomDataset(train_datasets[dataset_size]['data'], train_datasets[dataset_size]['label'])
val_dataset = CustomDataset(val_datasets[dataset_size]['data'], val_datasets[dataset_size]['label'])

# Define batch size
batch_size = 32

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Task 1: Train the Conditional INN
train_loss, val_loss= train_and_validate_conditional_nvp(conditional_inn_model, train_loader, val_loader,
                                                         num_epochs=num_epochs, lr=lr, print_after=1)

# # plotting the loss
# plot_losses(train_losses[3:], val_losses[3:], want_log_scale=0)
# plt.show()

task 2. evaluate the conditional model p(x|y)¶

In [ ]:
# Choose a label for evaluation (e.g., label 0)
eval_condition = torch.tensor([[1.0, 0.0]])  # One-hot encoding for label 0 (float, to match the data dtype)

# Repeat the condition vector for each sample
eval_condition = eval_condition.repeat(1000, 1)

with torch.no_grad():
    # Generate synthetic samples for the chosen label
    synthetic_samples_label_0 = conditional_inn_model.sample(num_samples=1000, conditions=eval_condition)

# # Example usage:
# plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
# plt.show()

### plot the synthetic data and the original data
visualize_synthetic_data(train_datasets[1000]['data'], synthetic_samples_label_0)
plt.show()

task 3. Merge synthetic data from all labels and compare the marginal distributions.¶

In [ ]:
# Generate synthetic samples for all labels
conditions_all_labels = torch.eye(condition_size)  # Assuming one-hot encoding

# Repeat the condition vector for each sample
conditions_all_labels = conditions_all_labels.repeat(1000, 1)

with torch.no_grad():
    synthetic_samples_all_labels = conditional_inn_model.sample(num_samples=2000, conditions=conditions_all_labels)

visualize_synthetic_data(train_datasets[1000]['data'], synthetic_samples_all_labels)
plt.show()

3. Higher-dimensional data with an INN¶

3.1 Loading new data and Hyperparameter tuning¶

First we load the dataset and create training subsets of different sizes

In [ ]:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import torch

# Load the digits dataset
digits = load_digits()

# Define the fractions of the training data to use
dataset_percentages = [0.1, 0.5, 1]
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2)
val_dataset = {'X': torch.FloatTensor(X_test), 'y': y_test}

# Generate datasets of varying sizes
train_datasets = {}
for percentage in dataset_percentages:
    # Take a subset of the digits dataset based on the desired size
    size = int(len(y_train)*percentage)
    X, y = X_train[:size], y_train[:size]
    train_datasets[percentage] = {'X': torch.FloatTensor(X), 'y': y}
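
Note that the digits features are raw pixel counts in [0, 16] and we feed them in unnormalised, which is one likely reason for the very large first-epoch loss further below. An optional rescaling to [0, 1] could look like the sketch here (train_datasets_scaled and val_dataset_scaled are names introduced only for illustration and are not used in the runs below):

In [ ]:
### optional preprocessing sketch: rescale pixels from [0, 16] to [0, 1]
train_datasets_scaled = {p: {'X': d['X'] / 16.0, 'y': d['y']} for p, d in train_datasets.items()}
val_dataset_scaled = {'X': val_dataset['X'] / 16.0, 'y': val_dataset['y']}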
In [ ]:
def plot_code_distribution(model, test_loader):
    """
    Plot the code distribution obtained by applying the trained RealNVP model to a test dataset.

    Args:
    - model (RealNVP): Trained RealNVP model.
    - test_loader (DataLoader): DataLoader for the test dataset.

    Returns:
    None (displays the plot).
    """
    model.eval()  # Set the model to evaluation mode
    fig, axs = plt.subplots(2, 5, figsize=(20, 7))
    with torch.no_grad():
        # Concatenate multiple batches to obtain more samples
        test_samples = torch.cat([batch for batch in test_loader], dim=0)
        # Assuming your model has an `encode` method
        code_samples = model.encode(test_samples)

        # Convert PyTorch tensor to numpy array
        code_np = code_samples.numpy()
        dim_1 = 0
        dim_2 = 1
        for i in range(2):
            for j in range(5):
                # Scatter plot of code distribution
                axs[i,j].scatter(code_np[:, dim_1], code_np[:, dim_2], label='Code Distribution', alpha=0.5)
                axs[i,j].set_xlabel(f"Code Dimension {dim_1}")
                axs[i,j].set_ylabel(f"Code Dimension {dim_2}")
                axs[i,j].set_title(f'Code dims {dim_1} vs {dim_2}')
                dim_1 += 1
                dim_2 += 1
        plt.tight_layout()
        plt.show()
In [ ]:
def visualize_synthetic_data(synthetic_data, title=""):
    """
    Scatter plot to visualize the original and synthetic data in 2D.

    Args:
    - synthetic_data (torch.Tensor): Synthetic data.

    Returns:
    - None: Displays the scatter plot.
    """
    fig, axs = plt.subplots(2, 5, figsize=(20, 7))
    with torch.no_grad():
        # Convert the PyTorch tensor to a numpy array
        synthetic_np = synthetic_data.numpy()
        count = 0
        for i in range(2):
            for j in range(5):
                axs[i,j].imshow(synthetic_np[count].reshape(8, 8), cmap='gray')
                axs[i,j].set_title(f'Synthetic Image: {count}')
                count += 1
        # Add an overall figure title
        fig.suptitle(title)
        plt.show()

Now we test whether our network works with the new dataset.

In [ ]:
input_size = 64
hidden_size = 200
blocks = 10

print_after = 1

# initialize the dataloaders
dataset_percentage = 0.5
batch_size = 32
data_considered = train_datasets[dataset_percentage]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset['X'], batch_size=batch_size, shuffle=True)


# instantiate the model
model = RealNVP(input_size=input_size, hidden_size=hidden_size, blocks=blocks)

## train the model
train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=0.001, print_after=1)

# plotting the loss
plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
plt.show()

# Example usage:
plot_code_distribution(model=model, test_loader=val_loader)
plt.show()

### plot the synthetic data and the original data
synthetic_data=model.sample(num_samples=len(data_considered))
visualize_synthetic_data(synthetic_data)
plt.show()
Epoch 1/10, Training Loss: 19308.744353252907, Validation Loss: 308.0241241455078
Epoch 2/10, Training Loss: 151.0023258043372, Validation Loss: 260.6484127044678
Epoch 3/10, Training Loss: 129.43456749294114, Validation Loss: 241.6455866495768
Epoch 4/10, Training Loss: 115.87873409105384, Validation Loss: 231.48429171244302
Epoch 5/10, Training Loss: 106.71935371730639, Validation Loss: 228.04992612202963
Epoch 6/10, Training Loss: 100.71591966048531, Validation Loss: 232.55555311838785
Epoch 7/10, Training Loss: 95.37316413547681, Validation Loss: 227.84586747487387
Epoch 8/10, Training Loss: 91.04394149780273, Validation Loss: 230.0734135309855
Epoch 9/10, Training Loss: 87.94942756321119, Validation Loss: 229.5797259012858
Epoch 10/10, Training Loss: 83.55091410097869, Validation Loss: 240.51764233907065
Training complete

The code distribution looks quite reasonable and approximately Gaussian. The synthesized data is beginning to look like digits, but it is still quite noisy. Next we try to find the optimal hyperparameters.
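
As a quick check beyond the scatter plots, we can also compare the empirical moments of the codes against a standard normal. This is a small added sketch (not part of the original run), reusing model and val_loader from the cell above; for codes distributed as N(0, I) we expect per-dimension means near 0 and variances near 1.

In [ ]:
# Sketch: quantify how close the code distribution is to a standard normal.
with torch.no_grad():
    codes = model.encode(torch.cat([batch for batch in val_loader], dim=0)).numpy()

print("average per-dimension mean:    ", codes.mean(axis=0).mean())
print("average per-dimension variance:", codes.var(axis=0).mean())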

In [ ]:
learning_rates = [0.01, 0.005, 0.0005, 0.0001, 0.000005]
dataset_percentage = 1.0
# Create data loader for the fixed dataset size
data_considered = train_datasets[dataset_percentage]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset['X'], batch_size=32, shuffle=True)

for lr in learning_rates:
    print(f"\nTraining for learning rate={lr}")

    # Instantiate the model
    model = RealNVP(input_size=input_size, hidden_size=hidden_size, blocks=blocks)

    # Train the model
    train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=30, lr=lr, print_after=1)

    # plotting the loss
    plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
    plt.show()

    # Example usage:
    plot_code_distribution(model=model, test_loader=val_loader)
    plt.show()

    ### plot the synthetic data
    synthetic_data=model.sample(num_samples=len(data_considered))
    visualize_synthetic_data(synthetic_data)
    
Training for learning rate=0.01
Epoch 1/30, Training Loss: 8009.836741299099, Validation Loss: 288.0275885264079
Epoch 2/30, Training Loss: 76.08619503445095, Validation Loss: 256.4709021250407
Epoch 3/30, Training Loss: 68.7188730875651, Validation Loss: 241.0039800008138
Epoch 4/30, Training Loss: 65.80661010742188, Validation Loss: 232.09345626831055
Epoch 5/30, Training Loss: 63.3587641398112, Validation Loss: 228.34066772460938
Epoch 6/30, Training Loss: 61.78260345458985, Validation Loss: 224.1435661315918
Epoch 7/30, Training Loss: 61.00211885240343, Validation Loss: 221.20492553710938
Epoch 8/30, Training Loss: 59.872193315294055, Validation Loss: 222.45267963409424
Epoch 9/30, Training Loss: 58.49332021077474, Validation Loss: 216.29072443644205
Epoch 10/30, Training Loss: 57.88600684271918, Validation Loss: 214.59373696645102
Epoch 11/30, Training Loss: 57.48555611504449, Validation Loss: 212.99246056874594
Epoch 12/30, Training Loss: 57.16072599622938, Validation Loss: 217.7744598388672
Epoch 13/30, Training Loss: 56.830532752143014, Validation Loss: 210.1707499821981
Epoch 14/30, Training Loss: 56.36947267320421, Validation Loss: 211.24468994140625
Epoch 15/30, Training Loss: 55.42196782430013, Validation Loss: 210.97901821136475
Epoch 16/30, Training Loss: 55.75337677001953, Validation Loss: 208.09878635406494
Epoch 17/30, Training Loss: 55.02196782430013, Validation Loss: 210.40148131052652
Epoch 18/30, Training Loss: 54.59654473198785, Validation Loss: 205.72683238983154
Epoch 19/30, Training Loss: 54.583910878499346, Validation Loss: 208.6036809285482
Epoch 20/30, Training Loss: 54.372396850585936, Validation Loss: 206.71865940093994
Epoch 21/30, Training Loss: 53.65991999308268, Validation Loss: 204.12037054697672
Epoch 22/30, Training Loss: 53.65253762139214, Validation Loss: 208.5793809890747
Epoch 23/30, Training Loss: 53.74359368218316, Validation Loss: 211.7449827194214
Epoch 24/30, Training Loss: 53.684653727213544, Validation Loss: 207.8568785985311
Epoch 25/30, Training Loss: 53.2072132534451, Validation Loss: 207.94135189056396
Epoch 26/30, Training Loss: 53.41186709933811, Validation Loss: 209.1408322652181
Epoch 27/30, Training Loss: 52.97333077324761, Validation Loss: 204.48490047454834
Epoch 28/30, Training Loss: 52.960265689425995, Validation Loss: 203.9521099726359
Epoch 29/30, Training Loss: 52.629005432128906, Validation Loss: 206.9331143697103
Epoch 30/30, Training Loss: 53.477005343967015, Validation Loss: 205.88944085439047
Training complete
Training for learning rate=0.005
Epoch 1/30, Training Loss: 11817.525778198242, Validation Loss: 270.10508664449054
Epoch 2/30, Training Loss: 69.870982615153, Validation Loss: 232.40014362335205
Epoch 3/30, Training Loss: 62.34513236151801, Validation Loss: 220.9430373509725
Epoch 4/30, Training Loss: 58.622605387369795, Validation Loss: 211.92736530303955
Epoch 5/30, Training Loss: 55.77464650472005, Validation Loss: 207.57684199015299
Epoch 6/30, Training Loss: 53.64491526285807, Validation Loss: 202.77793534596762
Epoch 7/30, Training Loss: 51.990141211615665, Validation Loss: 199.2128356297811
Epoch 8/30, Training Loss: 50.460140482584634, Validation Loss: 198.17747497558594
Epoch 9/30, Training Loss: 50.364815860324434, Validation Loss: 198.5120356877645
Epoch 10/30, Training Loss: 49.0253177218967, Validation Loss: 196.53202692667642
Epoch 11/30, Training Loss: 47.818701171875, Validation Loss: 194.4502503077189
Epoch 12/30, Training Loss: 47.43562757703993, Validation Loss: 191.5679677327474
Epoch 13/30, Training Loss: 46.55343619452582, Validation Loss: 197.71495151519775
Epoch 14/30, Training Loss: 45.96852815416124, Validation Loss: 191.3148282368978
Epoch 15/30, Training Loss: 45.09313100179036, Validation Loss: 195.80774847666422
Epoch 16/30, Training Loss: 45.09204661051432, Validation Loss: 192.35662587483725
Epoch 17/30, Training Loss: 44.720115661621094, Validation Loss: 191.43877792358398
Epoch 18/30, Training Loss: 44.07941326565213, Validation Loss: 193.19898001352945
Epoch 19/30, Training Loss: 44.19743118286133, Validation Loss: 190.72268931070963
Epoch 20/30, Training Loss: 43.4016234503852, Validation Loss: 185.80727926890054
Epoch 21/30, Training Loss: 43.52100160386827, Validation Loss: 188.30844116210938
Epoch 22/30, Training Loss: 42.71001519097222, Validation Loss: 192.54564380645752
Epoch 23/30, Training Loss: 42.025776926676436, Validation Loss: 192.11276976267496
Epoch 24/30, Training Loss: 42.15523783365885, Validation Loss: 189.09472274780273
Epoch 25/30, Training Loss: 41.77534535725911, Validation Loss: 191.74198309580484
Epoch 26/30, Training Loss: 41.31961042616103, Validation Loss: 194.46389230092367
Epoch 27/30, Training Loss: 40.85063256157769, Validation Loss: 195.73884105682373
Epoch 28/30, Training Loss: 40.96485850016276, Validation Loss: 193.3294941584269
Epoch 29/30, Training Loss: 41.45541627671984, Validation Loss: 194.46637217203775
Epoch 30/30, Training Loss: 40.55387997097439, Validation Loss: 190.83805497487387
Training complete
Training for learning rate=0.0005
Epoch 1/30, Training Loss: 3421.306739637587, Validation Loss: 288.54217465718585
Epoch 2/30, Training Loss: 73.53326305813259, Validation Loss: 245.20367558797201
Epoch 3/30, Training Loss: 64.15791592068142, Validation Loss: 227.901712735494
Epoch 4/30, Training Loss: 58.57548853556315, Validation Loss: 217.85050106048584
Epoch 5/30, Training Loss: 55.11601825290256, Validation Loss: 212.453226407369
Epoch 6/30, Training Loss: 52.042591603597, Validation Loss: 209.29950936635336
Epoch 7/30, Training Loss: 49.50460010104709, Validation Loss: 204.23781808217367
Epoch 8/30, Training Loss: 47.800843217637805, Validation Loss: 207.67214838663736
Epoch 9/30, Training Loss: 46.12979066636827, Validation Loss: 207.23672898610434
Epoch 10/30, Training Loss: 44.459814368353946, Validation Loss: 202.5842374165853
Epoch 11/30, Training Loss: 43.03138614230686, Validation Loss: 207.65548356374106
Epoch 12/30, Training Loss: 41.74035085042318, Validation Loss: 208.69722620646158
Epoch 13/30, Training Loss: 40.926596577962236, Validation Loss: 213.01070054372153
Epoch 14/30, Training Loss: 39.39276572333442, Validation Loss: 210.76309076944986
Epoch 15/30, Training Loss: 38.11468183729384, Validation Loss: 221.10226694742838
Epoch 16/30, Training Loss: 37.34300376044379, Validation Loss: 216.53514099121094
Epoch 17/30, Training Loss: 36.546615261501735, Validation Loss: 226.68182563781738
Epoch 18/30, Training Loss: 36.00442470974392, Validation Loss: 228.22372436523438
Epoch 19/30, Training Loss: 34.968190087212456, Validation Loss: 227.42056020100912
Epoch 20/30, Training Loss: 34.6193118625217, Validation Loss: 221.6729122797648
Epoch 21/30, Training Loss: 34.06460901896159, Validation Loss: 240.29690742492676
Epoch 22/30, Training Loss: 33.063655853271484, Validation Loss: 242.19485092163086
Epoch 23/30, Training Loss: 32.59690784878201, Validation Loss: 249.20933310190836
Epoch 24/30, Training Loss: 31.83392054239909, Validation Loss: 248.67272027333578
Epoch 25/30, Training Loss: 31.582435353597006, Validation Loss: 256.09630997975665
Epoch 26/30, Training Loss: 30.775786675347224, Validation Loss: 266.1253210703532
Epoch 27/30, Training Loss: 30.550120205349394, Validation Loss: 252.3609733581543
Epoch 28/30, Training Loss: 30.08771603902181, Validation Loss: 256.7566725413005
Epoch 29/30, Training Loss: 29.445677778455945, Validation Loss: 263.6085141499837
Epoch 30/30, Training Loss: 28.81046553717719, Validation Loss: 275.15847905476886
Training complete
Training for learning rate=0.0001
Epoch 1/30, Training Loss: 50113.41761644151, Validation Loss: 436.48069826761883
Epoch 2/30, Training Loss: 101.97941606309679, Validation Loss: 320.2299327850342
Epoch 3/30, Training Loss: 85.60570593939887, Validation Loss: 290.40064366658527
Epoch 4/30, Training Loss: 78.33708835177951, Validation Loss: 272.7111422220866
Epoch 5/30, Training Loss: 73.4560770670573, Validation Loss: 260.9111525217692
Epoch 6/30, Training Loss: 69.79727257622613, Validation Loss: 252.03315226236978
Epoch 7/30, Training Loss: 66.88374820285374, Validation Loss: 245.57611910502115
Epoch 8/30, Training Loss: 64.33567445543078, Validation Loss: 239.41691970825195
Epoch 9/30, Training Loss: 62.18655675252278, Validation Loss: 234.46007283528647
Epoch 10/30, Training Loss: 60.42937969631619, Validation Loss: 231.49946689605713
Epoch 11/30, Training Loss: 58.72264811197917, Validation Loss: 229.10130310058594
Epoch 12/30, Training Loss: 57.17389127943251, Validation Loss: 226.6750087738037
Epoch 13/30, Training Loss: 55.773566012912326, Validation Loss: 225.00317096710205
Epoch 14/30, Training Loss: 54.53753424750434, Validation Loss: 222.3171361287435
Epoch 15/30, Training Loss: 53.36913740370009, Validation Loss: 222.87827587127686
Epoch 16/30, Training Loss: 52.27295362684462, Validation Loss: 223.05616505940756
Epoch 17/30, Training Loss: 51.29743626912435, Validation Loss: 222.3541399637858
Epoch 18/30, Training Loss: 50.20012520684136, Validation Loss: 219.63851642608643
Epoch 19/30, Training Loss: 49.292029995388454, Validation Loss: 223.75088246663412
Epoch 20/30, Training Loss: 48.423710378011066, Validation Loss: 225.47177600860596
Epoch 21/30, Training Loss: 47.65950080023872, Validation Loss: 224.14251295725504
Epoch 22/30, Training Loss: 46.83405244615343, Validation Loss: 224.9314339955648
Epoch 23/30, Training Loss: 46.0850218878852, Validation Loss: 228.0197811126709
Epoch 24/30, Training Loss: 45.30505820380317, Validation Loss: 230.37444273630777
Epoch 25/30, Training Loss: 44.65458026462131, Validation Loss: 232.04700247446695
Epoch 26/30, Training Loss: 43.974732123480905, Validation Loss: 237.41689682006836
Epoch 27/30, Training Loss: 43.24482091267904, Validation Loss: 234.12914307912192
Epoch 28/30, Training Loss: 42.66191660563151, Validation Loss: 237.63699022928873
Epoch 29/30, Training Loss: 41.99319805569119, Validation Loss: 241.9967892964681
Epoch 30/30, Training Loss: 41.384490712483725, Validation Loss: 248.90463892618814
Training complete
Training for learning rate=5e-06
Epoch 1/30, Training Loss: 209076.60121527777, Validation Loss: 405062.1399739583
Epoch 2/30, Training Loss: 69862.05911458333, Validation Loss: 127279.57747395833
Epoch 3/30, Training Loss: 22368.46755642361, Validation Loss: 39132.588623046875
Epoch 4/30, Training Loss: 7067.427842881944, Validation Loss: 13500.332214355469
Epoch 5/30, Training Loss: 2560.1897352430556, Validation Loss: 5463.914642333984
Epoch 6/30, Training Loss: 1113.8959011501736, Validation Loss: 2742.546442667643
Epoch 7/30, Training Loss: 608.1996073404948, Validation Loss: 1645.3327814737956
Epoch 8/30, Training Loss: 386.06957058376736, Validation Loss: 1121.6913019816081
Epoch 9/30, Training Loss: 277.5015930175781, Validation Loss: 843.8144543965658
Epoch 10/30, Training Loss: 216.93083055284288, Validation Loss: 683.6320292154948
Epoch 11/30, Training Loss: 180.46121758355034, Validation Loss: 585.8069547017416
Epoch 12/30, Training Loss: 157.36595323350696, Validation Loss: 522.2299512227377
Epoch 13/30, Training Loss: 142.15333811442056, Validation Loss: 478.9925130208333
Epoch 14/30, Training Loss: 131.64742160373265, Validation Loss: 448.8523184458415
Epoch 15/30, Training Loss: 124.23250885009766, Validation Loss: 426.7560176849365
Epoch 16/30, Training Loss: 118.68551194932726, Validation Loss: 410.0090611775716
Epoch 17/30, Training Loss: 114.39310455322266, Validation Loss: 396.76917203267413
Epoch 18/30, Training Loss: 110.9672132703993, Validation Loss: 386.13282267252606
Epoch 19/30, Training Loss: 108.15684068467883, Validation Loss: 377.21822293599445
Epoch 20/30, Training Loss: 105.78033989800348, Validation Loss: 369.6262327829997
Epoch 21/30, Training Loss: 103.73234710693359, Validation Loss: 363.0723139444987
Epoch 22/30, Training Loss: 101.93745642768012, Validation Loss: 357.2292785644531
Epoch 23/30, Training Loss: 100.34205101860894, Validation Loss: 352.077735265096
Epoch 24/30, Training Loss: 98.9117190890842, Validation Loss: 347.46399943033856
Epoch 25/30, Training Loss: 97.61323394775391, Validation Loss: 343.2505931854248
Epoch 26/30, Training Loss: 96.42156558566623, Validation Loss: 339.3616517384847
Epoch 27/30, Training Loss: 95.32410888671875, Validation Loss: 335.8017037709554
Epoch 28/30, Training Loss: 94.30238511827257, Validation Loss: 332.52455139160156
Epoch 29/30, Training Loss: 93.36457349989149, Validation Loss: 329.5176232655843
Epoch 30/30, Training Loss: 92.4872580634223, Validation Loss: 326.7073853810628
Training complete

We see that the best results for the validation loss and the generated images are obtained with a learning rate of lr = 0.005. Since the algorithm already starts to overfit at around epoch 19, we do not need to analyze the epoch count and the dataset size separately: the overfitting point gives us the best epoch count, and with overfitting already occurring on the full dataset it would not make sense to test smaller dataset sizes.
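
Instead of reading the overfitting point off the loss curves by eye, we could also locate it programmatically. The helper below is a hypothetical sketch (not used in the runs above): it returns the epoch with the lowest validation loss and stops scanning once the loss has not improved for a fixed patience.

In [ ]:
# Hypothetical helper: find the best epoch from a list of validation losses,
# stopping early once `patience` epochs pass without improvement.
def find_best_epoch(val_losses, patience=5):
    best_epoch, best_loss = 0, float('inf')
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
        elif epoch - best_epoch >= patience:
            break
    return best_epoch, best_loss

best_epoch, best_loss = find_best_epoch(val_losses)
print(f"Best epoch: {best_epoch + 1}, validation loss: {best_loss:.2f}")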

In [ ]:
hidden_sizes = [100, 200, 400]
blocks = [2, 5, 10]
input_size = 64

dataset_percentage = 1.0
# Create data loader for the fixed dataset size
data_considered = train_datasets[dataset_percentage]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset['X'], batch_size=32, shuffle=True)

for hidden_size in hidden_sizes:
    for block in blocks:
        print(f"\nTraining for hidden_size={hidden_size}, blocks = {block}")

        # Instantiate the model
        model = RealNVP(input_size=input_size, hidden_size=hidden_size, blocks=block)

        # Train the model
        train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=30, lr=0.005, print_after=1)

        # plotting the loss
        plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
        plt.show()

        # Example usage:
        plot_code_distribution(model=model, test_loader=val_loader)
        plt.show()

        ### plot the synthetic data
        synthetic_data=model.sample(num_samples=len(data_considered))
        visualize_synthetic_data(synthetic_data)
        
Training for hidden_size=100, blocks = 2
Epoch 1/30, Training Loss: 294.7230936686198, Validation Loss: 571.6155484517416
Epoch 2/30, Training Loss: 141.9142296685113, Validation Loss: 468.15574010213214
Epoch 3/30, Training Loss: 124.05112999810113, Validation Loss: 420.20416831970215
Epoch 4/30, Training Loss: 114.67127126057943, Validation Loss: 405.47664070129395
Epoch 5/30, Training Loss: 112.06046108669705, Validation Loss: 401.04412778218585
Epoch 6/30, Training Loss: 109.06603291829427, Validation Loss: 385.57850710550946
Epoch 7/30, Training Loss: 104.84836205376519, Validation Loss: 383.4814968109131
Epoch 8/30, Training Loss: 102.67653469509548, Validation Loss: 399.0242977142334
Epoch 9/30, Training Loss: 102.31197408040364, Validation Loss: 360.4715805053711
Epoch 10/30, Training Loss: 99.84381306966146, Validation Loss: 357.0197696685791
Epoch 11/30, Training Loss: 99.02914411756727, Validation Loss: 355.18566576639813
Epoch 12/30, Training Loss: 97.31855146620009, Validation Loss: 351.6528517405192
Epoch 13/30, Training Loss: 96.65306871202257, Validation Loss: 358.52728335062665
Epoch 14/30, Training Loss: 95.8873289320204, Validation Loss: 358.1687469482422
Epoch 15/30, Training Loss: 94.39329071044922, Validation Loss: 355.81766446431476
Epoch 16/30, Training Loss: 94.19340837266711, Validation Loss: 337.2964045206706
Epoch 17/30, Training Loss: 92.56639607747395, Validation Loss: 338.4132130940755
Epoch 18/30, Training Loss: 92.66181182861328, Validation Loss: 327.18345133463544
Epoch 19/30, Training Loss: 92.0412845187717, Validation Loss: 331.57775179545087
Epoch 20/30, Training Loss: 90.90107608371311, Validation Loss: 334.78138478597003
Epoch 21/30, Training Loss: 89.76652086046008, Validation Loss: 332.50303332010907
Epoch 22/30, Training Loss: 89.18487158881294, Validation Loss: 332.22913614908856
Epoch 23/30, Training Loss: 89.43389723036024, Validation Loss: 327.5421390533447
Epoch 24/30, Training Loss: 88.90379621717665, Validation Loss: 328.9135939280192
Epoch 25/30, Training Loss: 89.27030453152126, Validation Loss: 313.00710614522296
Epoch 26/30, Training Loss: 88.07124938964844, Validation Loss: 307.8485215504964
Epoch 27/30, Training Loss: 86.89611290825738, Validation Loss: 308.5803909301758
Epoch 28/30, Training Loss: 85.90565677218967, Validation Loss: 328.0444863637288
Epoch 29/30, Training Loss: 88.49742211235895, Validation Loss: 319.33860270182294
Epoch 30/30, Training Loss: 88.56594645182291, Validation Loss: 309.60890515645343
Training complete
Training for hidden_size=100, blocks = 5
Epoch 1/30, Training Loss: 265.01097886827256, Validation Loss: 301.99928347269696
Epoch 2/30, Training Loss: 77.97027062310113, Validation Loss: 256.44384956359863
Epoch 3/30, Training Loss: 69.68282352023654, Validation Loss: 241.7223745981852
Epoch 4/30, Training Loss: 65.50670505099826, Validation Loss: 230.7071574529012
Epoch 5/30, Training Loss: 62.328228335910374, Validation Loss: 224.24456691741943
Epoch 6/30, Training Loss: 60.30294715033637, Validation Loss: 218.9724578857422
Epoch 7/30, Training Loss: 58.39633288913303, Validation Loss: 218.47990926106772
Epoch 8/30, Training Loss: 56.874779001871744, Validation Loss: 212.0576960245768
Epoch 9/30, Training Loss: 56.036871846516924, Validation Loss: 210.8950433731079
Epoch 10/30, Training Loss: 55.106080034044055, Validation Loss: 207.61690107981363
Epoch 11/30, Training Loss: 54.76850458780925, Validation Loss: 213.44031524658203
Epoch 12/30, Training Loss: 53.71915673149957, Validation Loss: 210.15140438079834
Epoch 13/30, Training Loss: 52.99325018988715, Validation Loss: 203.27078851064047
Epoch 14/30, Training Loss: 52.25531472100152, Validation Loss: 202.80283133188883
Epoch 15/30, Training Loss: 51.514439307318796, Validation Loss: 200.92092609405518
Epoch 16/30, Training Loss: 51.71739222208659, Validation Loss: 199.66056442260742
Epoch 17/30, Training Loss: 50.8276117960612, Validation Loss: 200.95742066701254
Epoch 18/30, Training Loss: 50.33599675496419, Validation Loss: 203.53884760538736
Epoch 19/30, Training Loss: 49.800680796305336, Validation Loss: 203.55471007029215
Epoch 20/30, Training Loss: 49.49485787285699, Validation Loss: 200.13616148630777
Epoch 21/30, Training Loss: 48.99873826768663, Validation Loss: 198.87054856618246
Epoch 22/30, Training Loss: 48.645048183865015, Validation Loss: 198.45662053426108
Epoch 23/30, Training Loss: 48.5343746609158, Validation Loss: 194.56494808197021
Epoch 24/30, Training Loss: 47.689731174045136, Validation Loss: 197.5397570927938
Epoch 25/30, Training Loss: 48.05319391886393, Validation Loss: 197.968976020813
Epoch 26/30, Training Loss: 47.807348039415146, Validation Loss: 198.26151434580484
Epoch 27/30, Training Loss: 47.789947001139325, Validation Loss: 195.22731018066406
Epoch 28/30, Training Loss: 47.258505249023436, Validation Loss: 198.5635643005371
Epoch 29/30, Training Loss: 47.24866765340169, Validation Loss: 197.3188044230143
Epoch 30/30, Training Loss: 47.136260477701825, Validation Loss: 197.86065769195557
Training complete
Training for hidden_size=100, blocks = 10
Epoch 1/30, Training Loss: 9436.184055074056, Validation Loss: 263.33286539713544
Epoch 2/30, Training Loss: 67.61033986409505, Validation Loss: 228.6816488901774
Epoch 3/30, Training Loss: 60.80630026923286, Validation Loss: 215.70353666941324
Epoch 4/30, Training Loss: 57.06355929904514, Validation Loss: 208.36373488108316
Epoch 5/30, Training Loss: 54.60490230984158, Validation Loss: 203.7723045349121
Epoch 6/30, Training Loss: 52.81721649169922, Validation Loss: 198.82628377278647
Epoch 7/30, Training Loss: 51.284493425157336, Validation Loss: 196.23615233103433
Epoch 8/30, Training Loss: 50.3772331237793, Validation Loss: 195.4634517033895
Epoch 9/30, Training Loss: 49.036874050564236, Validation Loss: 190.6725641886393
Epoch 10/30, Training Loss: 47.81766933865018, Validation Loss: 191.61908117930093
Epoch 11/30, Training Loss: 47.524300469292534, Validation Loss: 191.0384251276652
Epoch 12/30, Training Loss: 46.822449662950305, Validation Loss: 191.2440538406372
Epoch 13/30, Training Loss: 45.9994388156467, Validation Loss: 190.5266834894816
Epoch 14/30, Training Loss: 45.582711622450084, Validation Loss: 188.55037212371826
Epoch 15/30, Training Loss: 44.89894324408637, Validation Loss: 187.62296676635742
Epoch 16/30, Training Loss: 44.38909420437283, Validation Loss: 190.2049217224121
Epoch 17/30, Training Loss: 44.68807390001085, Validation Loss: 189.20528157552084
Epoch 18/30, Training Loss: 43.929205830891924, Validation Loss: 189.91830094655356
Epoch 19/30, Training Loss: 43.186015065511064, Validation Loss: 188.0519240697225
Epoch 20/30, Training Loss: 42.362147352430554, Validation Loss: 186.8313970565796
Epoch 21/30, Training Loss: 42.51196111043294, Validation Loss: 186.515017191569
Epoch 22/30, Training Loss: 41.920947265625, Validation Loss: 189.6102253595988
Epoch 23/30, Training Loss: 41.618788146972655, Validation Loss: 187.03820164998373
Epoch 24/30, Training Loss: 41.889136505126956, Validation Loss: 183.9749552408854
Epoch 25/30, Training Loss: 41.11484959920247, Validation Loss: 187.04142888387045
Epoch 26/30, Training Loss: 41.106873491075305, Validation Loss: 185.5558303197225
Epoch 27/30, Training Loss: 40.58207711113824, Validation Loss: 190.92296314239502
Epoch 28/30, Training Loss: 40.81068674723307, Validation Loss: 193.93083826700845
Epoch 29/30, Training Loss: 40.431168450249565, Validation Loss: 188.75267124176025
Epoch 30/30, Training Loss: 39.85402603149414, Validation Loss: 184.7258857091268
Training complete
Training for hidden_size=200, blocks = 2
Epoch 1/30, Training Loss: 216.53725297715928, Validation Loss: 444.19052505493164
Epoch 2/30, Training Loss: 113.37141401502821, Validation Loss: 364.4847469329834
Epoch 3/30, Training Loss: 100.15595075819228, Validation Loss: 345.93261528015137
Epoch 4/30, Training Loss: 94.53593071831597, Validation Loss: 331.55730120340985
Epoch 5/30, Training Loss: 92.21392093234591, Validation Loss: 324.0475031534831
Epoch 6/30, Training Loss: 88.54855431450738, Validation Loss: 308.84101994832355
Epoch 7/30, Training Loss: 86.4598612467448, Validation Loss: 309.0832977294922
Epoch 8/30, Training Loss: 84.3639392428928, Validation Loss: 306.3279215494792
Epoch 9/30, Training Loss: 83.28580000135634, Validation Loss: 311.41941324869794
Epoch 10/30, Training Loss: 81.50191514756945, Validation Loss: 296.9877471923828
Epoch 11/30, Training Loss: 79.26307542588975, Validation Loss: 278.895850499471
Epoch 12/30, Training Loss: 78.94070926242405, Validation Loss: 271.9144166310628
Epoch 13/30, Training Loss: 78.2772957695855, Validation Loss: 293.2340513865153
Epoch 14/30, Training Loss: 78.42385457356771, Validation Loss: 287.2551898956299
Epoch 15/30, Training Loss: 76.45130123562284, Validation Loss: 274.488676071167
Epoch 16/30, Training Loss: 76.47469346788195, Validation Loss: 276.59137535095215
Epoch 17/30, Training Loss: 75.16443176269532, Validation Loss: 275.96004931132
Epoch 18/30, Training Loss: 76.47992943657769, Validation Loss: 270.53071784973145
Epoch 19/30, Training Loss: 74.54247843424479, Validation Loss: 273.3419183095296
Epoch 20/30, Training Loss: 74.45988430447049, Validation Loss: 276.253563563029
Epoch 21/30, Training Loss: 74.91201510959202, Validation Loss: 280.04852358500165
Epoch 22/30, Training Loss: 74.06419660780165, Validation Loss: 261.0175202687581
Epoch 23/30, Training Loss: 73.2760986328125, Validation Loss: 266.8240426381429
Epoch 24/30, Training Loss: 73.63587934705946, Validation Loss: 268.8825174967448
Epoch 25/30, Training Loss: 72.39171634250216, Validation Loss: 256.8872324625651
Epoch 26/30, Training Loss: 71.26632826063368, Validation Loss: 260.2430502573649
Epoch 27/30, Training Loss: 72.33309190538195, Validation Loss: 263.19954744974774
Epoch 28/30, Training Loss: 71.85370229085287, Validation Loss: 264.2111365000407
Epoch 29/30, Training Loss: 70.93008236355251, Validation Loss: 263.0232054392497
Epoch 30/30, Training Loss: 71.94356502956815, Validation Loss: 258.7958990732829
Training complete
Training for hidden_size=200, blocks = 5
Epoch 1/30, Training Loss: 309.551146613227, Validation Loss: 296.65241622924805
Epoch 2/30, Training Loss: 77.70625271267362, Validation Loss: 259.46020062764484
Epoch 3/30, Training Loss: 69.53970616658529, Validation Loss: 240.80973148345947
Epoch 4/30, Training Loss: 65.07096065945096, Validation Loss: 229.57966740926108
Epoch 5/30, Training Loss: 62.08271713256836, Validation Loss: 223.40517075856528
Epoch 6/30, Training Loss: 59.97387279934353, Validation Loss: 221.47677834828696
Epoch 7/30, Training Loss: 58.47309239705404, Validation Loss: 217.21256033579508
Epoch 8/30, Training Loss: 56.85402018229167, Validation Loss: 214.76035340627035
Epoch 9/30, Training Loss: 55.775561014811196, Validation Loss: 210.1625213623047
Epoch 10/30, Training Loss: 54.69895239935981, Validation Loss: 212.7543576558431
Epoch 11/30, Training Loss: 54.13418511284722, Validation Loss: 206.8346061706543
Epoch 12/30, Training Loss: 53.188574133978946, Validation Loss: 206.58446153004965
Epoch 13/30, Training Loss: 52.80375417073568, Validation Loss: 209.3851442337036
Epoch 14/30, Training Loss: 52.33621393839518, Validation Loss: 202.7700351079305
Epoch 15/30, Training Loss: 51.24921044243707, Validation Loss: 200.8417387008667
Epoch 16/30, Training Loss: 50.287712012396916, Validation Loss: 201.89845403035483
Epoch 17/30, Training Loss: 49.85312194824219, Validation Loss: 202.5632349650065
Epoch 18/30, Training Loss: 49.33689439561632, Validation Loss: 199.69777806599936
Epoch 19/30, Training Loss: 49.72008887396918, Validation Loss: 204.82231680552164
Epoch 20/30, Training Loss: 49.121662309434676, Validation Loss: 202.7250280380249
Epoch 21/30, Training Loss: 49.025482940673825, Validation Loss: 206.9731995264689
Epoch 22/30, Training Loss: 48.66969256930881, Validation Loss: 202.8100382486979
Epoch 23/30, Training Loss: 47.902609168158634, Validation Loss: 202.15858459472656
Epoch 24/30, Training Loss: 47.533621554904514, Validation Loss: 199.37864557902017
Epoch 25/30, Training Loss: 46.99391530354818, Validation Loss: 201.3996795018514
Epoch 26/30, Training Loss: 47.13845409817166, Validation Loss: 200.77649021148682
Epoch 27/30, Training Loss: 47.37102983262804, Validation Loss: 199.04171403249106
Epoch 28/30, Training Loss: 46.85574535793728, Validation Loss: 204.1581137975057
Epoch 29/30, Training Loss: 46.16206520928277, Validation Loss: 201.75563176472983
Epoch 30/30, Training Loss: 46.21358116997613, Validation Loss: 202.40555699666342
Training complete
Training for hidden_size=200, blocks = 10
Epoch 1/30, Training Loss: 6492.414649454752, Validation Loss: 282.70446332295734
Epoch 2/30, Training Loss: 73.40664850870768, Validation Loss: 246.55630683898926
Epoch 3/30, Training Loss: 64.3135969373915, Validation Loss: 226.65690644582114
Epoch 4/30, Training Loss: 60.09317186143663, Validation Loss: 215.6777229309082
Epoch 5/30, Training Loss: 57.14421852959527, Validation Loss: 210.23212718963623
Epoch 6/30, Training Loss: 55.430495198567705, Validation Loss: 204.4380890528361
Epoch 7/30, Training Loss: 53.62228147718641, Validation Loss: 208.22985331217447
Epoch 8/30, Training Loss: 52.40490570068359, Validation Loss: 201.35067780812582
Epoch 9/30, Training Loss: 51.76244778103299, Validation Loss: 199.5050137837728
Epoch 10/30, Training Loss: 50.03342141045464, Validation Loss: 195.68669923146567
Epoch 11/30, Training Loss: 49.121512010362416, Validation Loss: 197.29208914438883
Epoch 12/30, Training Loss: 48.897773827446834, Validation Loss: 196.3861207962036
Epoch 13/30, Training Loss: 48.12302941216363, Validation Loss: 196.06686433156332
Epoch 14/30, Training Loss: 47.333321211073134, Validation Loss: 191.7859255472819
Epoch 15/30, Training Loss: 47.04701470269097, Validation Loss: 194.23813724517822
Epoch 16/30, Training Loss: 46.205683898925784, Validation Loss: 193.08368174235025
Epoch 17/30, Training Loss: 45.86734517415365, Validation Loss: 194.6445223490397
Epoch 18/30, Training Loss: 45.691319190131296, Validation Loss: 190.83626715342203
Epoch 19/30, Training Loss: 44.578766123453775, Validation Loss: 198.11820379892984
Epoch 20/30, Training Loss: 44.48360688951281, Validation Loss: 196.22168699900308
Epoch 21/30, Training Loss: 44.72951049804688, Validation Loss: 192.11080837249756
Epoch 22/30, Training Loss: 44.1974353366428, Validation Loss: 193.88198947906494
Epoch 23/30, Training Loss: 44.18268771701389, Validation Loss: 196.84188238779703
Epoch 24/30, Training Loss: 43.73227649264865, Validation Loss: 193.48014958699545
Epoch 25/30, Training Loss: 43.711177571614584, Validation Loss: 194.7969299952189
Epoch 26/30, Training Loss: 42.93069229125977, Validation Loss: 191.0793244043986
Epoch 27/30, Training Loss: 42.515458255343965, Validation Loss: 191.73682816823325
Epoch 28/30, Training Loss: 43.338741048177084, Validation Loss: 193.38479391733804
Epoch 29/30, Training Loss: 42.78940794203017, Validation Loss: 194.90061601003012
Epoch 30/30, Training Loss: 42.59330757988824, Validation Loss: 195.17180665334067
Training complete
Training for hidden_size=400, blocks = 2
Epoch 1/30, Training Loss: 307.47569749620226, Validation Loss: 660.282797495524
Epoch 2/30, Training Loss: 180.9752634684245, Validation Loss: 812.9520988464355
Epoch 3/30, Training Loss: 173.47369859483507, Validation Loss: 569.1646308898926
Epoch 4/30, Training Loss: 159.17198621961805, Validation Loss: 565.31818262736
Epoch 5/30, Training Loss: 158.90091518825955, Validation Loss: 551.1031494140625
Epoch 6/30, Training Loss: 159.60764600965712, Validation Loss: 574.1276728312174
Epoch 7/30, Training Loss: 145.9571780734592, Validation Loss: 495.2603123982747
Epoch 8/30, Training Loss: 142.371486070421, Validation Loss: 491.8583634694417
Epoch 9/30, Training Loss: 141.50491773817274, Validation Loss: 498.96835072835285
Epoch 10/30, Training Loss: 143.82400377061632, Validation Loss: 492.4005317687988
Epoch 11/30, Training Loss: 146.19541456434462, Validation Loss: 574.3242225646973
Epoch 12/30, Training Loss: 145.28963758680555, Validation Loss: 528.7288983662924
Epoch 13/30, Training Loss: 136.63750644259983, Validation Loss: 515.2850685119629
Epoch 14/30, Training Loss: 142.1431399875217, Validation Loss: 489.8877480824788
Epoch 15/30, Training Loss: 142.1147196451823, Validation Loss: 474.4272352854411
Epoch 16/30, Training Loss: 140.7830559624566, Validation Loss: 495.7752113342285
Epoch 17/30, Training Loss: 147.55852644178603, Validation Loss: 470.71923510233563
Epoch 18/30, Training Loss: 141.1549747043186, Validation Loss: 539.7263113657633
Epoch 19/30, Training Loss: 135.7890879313151, Validation Loss: 485.22253545125324
Epoch 20/30, Training Loss: 135.64784579806857, Validation Loss: 490.24193954467773
Epoch 21/30, Training Loss: 130.6491214328342, Validation Loss: 452.17605781555176
Epoch 22/30, Training Loss: 126.53772633870443, Validation Loss: 451.6606750488281
Epoch 23/30, Training Loss: 127.67455647786458, Validation Loss: 449.65537707010907
Epoch 24/30, Training Loss: 127.42182057698568, Validation Loss: 488.5056915283203
Epoch 25/30, Training Loss: 127.36804436577691, Validation Loss: 496.6194330851237
Epoch 26/30, Training Loss: 133.38234303792316, Validation Loss: 489.46679941813153
Epoch 27/30, Training Loss: 130.7207258436415, Validation Loss: 513.3341700236002
Epoch 28/30, Training Loss: 127.46561448838976, Validation Loss: 457.0716412862142
Epoch 29/30, Training Loss: 125.12691226535374, Validation Loss: 442.3211104075114
Epoch 30/30, Training Loss: 124.04728088378906, Validation Loss: 441.13921610514325
Training complete
Training for hidden_size=400, blocks = 5
Epoch 1/30, Training Loss: 316.47518547905815, Validation Loss: 309.94998868306476
Epoch 2/30, Training Loss: 82.13634084065755, Validation Loss: 275.5159142812093
Epoch 3/30, Training Loss: 74.84136166042752, Validation Loss: 262.6661961873372
Epoch 4/30, Training Loss: 70.7937240600586, Validation Loss: 249.78180122375488
Epoch 5/30, Training Loss: 68.1334605746799, Validation Loss: 245.63199742635092
Epoch 6/30, Training Loss: 65.83484369913737, Validation Loss: 234.56178633371988
Epoch 7/30, Training Loss: 63.888429175482855, Validation Loss: 240.55420271555582
Epoch 8/30, Training Loss: 62.737112765842014, Validation Loss: 235.59585825602213
Epoch 9/30, Training Loss: 62.29926893446181, Validation Loss: 233.10778681437174
Epoch 10/30, Training Loss: 61.70500666300456, Validation Loss: 228.81559340159097
Epoch 11/30, Training Loss: 60.07350964016385, Validation Loss: 227.86209615071616
Epoch 12/30, Training Loss: 59.456063503689236, Validation Loss: 226.61377970377603
Epoch 13/30, Training Loss: 58.891603766547306, Validation Loss: 227.92042700449625
Epoch 14/30, Training Loss: 58.36316468980577, Validation Loss: 224.34271812438965
Epoch 15/30, Training Loss: 57.4016359117296, Validation Loss: 224.10171476999918
Epoch 16/30, Training Loss: 58.129408518473305, Validation Loss: 221.11990706125894
Epoch 17/30, Training Loss: 57.24721306694879, Validation Loss: 227.1364345550537
Epoch 18/30, Training Loss: 56.437878672281904, Validation Loss: 224.80929819742838
Epoch 19/30, Training Loss: 56.5604724460178, Validation Loss: 220.81223583221436
Epoch 20/30, Training Loss: 54.89689534505208, Validation Loss: 220.16179784138998
Epoch 21/30, Training Loss: 54.49175864325629, Validation Loss: 223.90884653727213
Epoch 22/30, Training Loss: 54.91522538926866, Validation Loss: 221.36671447753906
Epoch 23/30, Training Loss: 54.64884719848633, Validation Loss: 224.3624588648478
Epoch 24/30, Training Loss: 54.53467610677083, Validation Loss: 219.23523680369058
Epoch 25/30, Training Loss: 54.39541456434462, Validation Loss: 225.01481914520264
Epoch 26/30, Training Loss: 53.75576400756836, Validation Loss: 224.30957667032877
Epoch 27/30, Training Loss: 53.983888668484155, Validation Loss: 222.11862341562906
Epoch 28/30, Training Loss: 53.86955354478624, Validation Loss: 225.67633787790933
Epoch 29/30, Training Loss: 53.31535432603624, Validation Loss: 224.4200760523478
Epoch 30/30, Training Loss: 52.715677218967016, Validation Loss: 223.37881247202554
Training complete
Training for hidden_size=400, blocks = 10
Epoch 1/30, Training Loss: 13539.159749179416, Validation Loss: 342.41082700093585
Epoch 2/30, Training Loss: 86.58183034261067, Validation Loss: 278.1364510854085
Epoch 3/30, Training Loss: 75.3851586235894, Validation Loss: 259.4295717875163
Epoch 4/30, Training Loss: 70.06561041937934, Validation Loss: 287.4711570739746
Epoch 5/30, Training Loss: 67.08853412204319, Validation Loss: 243.05325508117676
Epoch 6/30, Training Loss: 64.71018481784397, Validation Loss: 239.3382879892985
Epoch 7/30, Training Loss: 63.28148990207248, Validation Loss: 233.9927625656128
Epoch 8/30, Training Loss: 61.262958187527126, Validation Loss: 225.98743216196695
Epoch 9/30, Training Loss: 60.458206515842015, Validation Loss: 230.18698183695474
Epoch 10/30, Training Loss: 58.9208381652832, Validation Loss: 224.2623545328776
Epoch 11/30, Training Loss: 58.24448852539062, Validation Loss: 221.45679410298666
Epoch 12/30, Training Loss: 56.951568094889325, Validation Loss: 219.85390186309814
Epoch 13/30, Training Loss: 56.31669930352105, Validation Loss: 219.88257439931235
Epoch 14/30, Training Loss: 55.53523585001628, Validation Loss: 222.5670598347982
Epoch 15/30, Training Loss: 54.74429762098524, Validation Loss: 213.9870694478353
Epoch 16/30, Training Loss: 54.23894585503472, Validation Loss: 224.2745631535848
Epoch 17/30, Training Loss: 53.661950005425346, Validation Loss: 217.9786860148112
Epoch 18/30, Training Loss: 54.17339884440104, Validation Loss: 217.71878623962402
Epoch 19/30, Training Loss: 53.162393188476564, Validation Loss: 213.33333587646484
Epoch 20/30, Training Loss: 52.72771631876628, Validation Loss: 216.45318190256754
Epoch 21/30, Training Loss: 53.03551466200087, Validation Loss: 219.1963389714559
Epoch 22/30, Training Loss: 53.08130594889323, Validation Loss: 221.36446475982666
Epoch 23/30, Training Loss: 51.74025141398112, Validation Loss: 217.7514041264852
Epoch 24/30, Training Loss: 51.72227588229709, Validation Loss: 220.6808811823527
Epoch 25/30, Training Loss: 51.49509489271376, Validation Loss: 213.47276719411215
Epoch 26/30, Training Loss: 51.085971408420136, Validation Loss: 215.52674261728922
Epoch 27/30, Training Loss: 50.79202465481228, Validation Loss: 212.3214489618937
Epoch 28/30, Training Loss: 50.732866838243275, Validation Loss: 212.61949157714844
Epoch 29/30, Training Loss: 50.685603586832684, Validation Loss: 212.96100012461343
Epoch 30/30, Training Loss: 50.355876753065324, Validation Loss: 218.3713518778483
Training complete

We see that the best network structure appears to be a hidden size of 100 with 10 blocks.
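
For context, the tested configurations also differ strongly in size; a quick way to compare them (a small added sketch, reusing the RealNVP class from above) is to count trainable parameters per configuration:

In [ ]:
# Sketch: compare the number of trainable parameters per configuration.
for hidden_size in [100, 200, 400]:
    for block in [2, 5, 10]:
        m = RealNVP(input_size=64, hidden_size=hidden_size, blocks=block)
        n_params = sum(p.numel() for p in m.parameters())
        print(f"hidden_size={hidden_size}, blocks={block}: {n_params} parameters")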

3.2 Adding artificial bottleneck¶

In [ ]:
class RealNVP_bottleneck(nn.Module):
    def __init__(self, input_size, hidden_size, blocks, k):
        super(RealNVP_bottleneck, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.blocks = blocks
        self.k = k

        # List of coupling layers
        self.coupling_layers = nn.ModuleList([
            CouplingLayer(input_size, hidden_size) for _ in range(blocks)
        ])


        # List to store orthonormal matrices
        self.orthonormal_matrices = [self._get_orthonormal_matrix(input_size) for _ in range(blocks)]

        # List to store scaling_before_exp for each block
        self.scaling_before_exp_list = []

    def _get_orthonormal_matrix(self, size):
        # Function to generate a random orthonormal matrix
        w = torch.randn(size, size)
        q, _ = torch.linalg.qr(w,'reduced')
        return q

    def forward_realnvp(self, x):
        scaling_before_exp_list = []
        for i in range(self.blocks):

            # Apply random orthonormal matrix
            x = torch.matmul(x, self.orthonormal_matrices[i])

            # Apply coupling layer
            x, scaling_before_exp = self.coupling_layers[i].forward(x)
            scaling_before_exp_list.append(scaling_before_exp)

        self.scaling_before_exp_list = scaling_before_exp_list
        return x

    def encode(self, x):
        # Encoding is the forward pass through the RealNVP model
        return self.forward_realnvp(x)

    def decode(self, z):
        # Modify z to zero out dimensions beyond k for the reconstruction
        z_reconstructed = z.clone()
        if self.k < self.input_size:
            z_reconstructed[:, self.k:] = 0  # Zero out dimensions beyond k

        # Proceed with the original decoding process
        for i in reversed(range(self.blocks)):
            z = self.coupling_layers[i].backward(z)
            z_reconstructed = self.coupling_layers[i].backward(z_reconstructed)
            z = torch.matmul(z, self.orthonormal_matrices[i].t())
            z_reconstructed = torch.matmul(z_reconstructed, self.orthonormal_matrices[i].t())

        return z, z_reconstructed

    def sample(self, num_samples=1000):
        # Generate random samples from a standard normal distribution
        with torch.no_grad():
            z = torch.randn(num_samples, self.input_size)

        # Apply the reverse transformations (decoder) to generate synthetic samples
        _,synthetic_samples = self.decode(z)
        return synthetic_samples
    
    def sample_only_important(self, num_samples=1000):
        # Draw the k important code dimensions from a standard normal
        # and set the remaining (unimportant) dimensions to zero
        with torch.no_grad():
            z_1 = torch.randn(num_samples, self.k)
            z_2 = torch.zeros(num_samples, self.input_size - self.k)
            z = torch.cat((z_1, z_2), dim=1)

        # Apply the reverse transformations (decoder) to generate synthetic samples
        _,synthetic_samples = self.decode(z)
        return synthetic_samples
    
    def sample_only_unimportant(self, num_samples=1000):
        # Fix the k important code dimensions to a single random draw
        # and resample only the unimportant dimensions for every sample
        with torch.no_grad():
            z_1 = torch.randn(1, self.k).repeat(num_samples, 1)
            z_2 = torch.randn(num_samples, self.input_size - self.k)
            z = torch.cat((z_1, z_2), dim=1)

        # Apply the reverse transformations (decoder) to generate synthetic samples
        _,synthetic_samples = self.decode(z)
        return synthetic_samples
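
Before training, a quick sanity check of the new class can be useful: with k equal to the input size, the bottleneck is inactive, so decoding the encoded data should reproduce the input up to numerical error. The cell below is a small added sketch under that assumption, not part of the original notebook.

In [ ]:
# Sanity check sketch: with k = input_size, decode(encode(x)) should recover x.
check_model = RealNVP_bottleneck(input_size=64, hidden_size=100, blocks=10, k=64)
x = train_datasets[1.0]['X'][:8]
with torch.no_grad():
    z = check_model.encode(x)
    x_back, _ = check_model.decode(z)
print("max reconstruction error:", (x - x_back).abs().max().item())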
In [ ]:
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

def train_and_evaluate(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1):
    """
    Train the RealNVP model and evaluate on a validation dataset.

    Args:
    - model (RealNVP): The RealNVP model to be trained.
    - train_loader (DataLoader): DataLoader for the training dataset.
    - val_loader (DataLoader): DataLoader for the validation dataset.
    - num_epochs (int): Number of training epochs.
    - lr (float): Learning rate for the optimizer.
    - print_after (int): Number of epochs after which to print the training and validation loss.

    Returns:
    - train_losses_nll, train_losses_recons (lists): Training NLL and reconstruction losses per epoch.
    - val_losses_nll, val_losses_recons (lists): Validation NLL and reconstruction losses per epoch.
    """

    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    train_losses_nll = [] 
    val_losses_nll = []    
    train_losses_recons = []  
    val_losses_recons = []    
    # Training phase
    model.train()  # Set the model to training mode
    
    for epoch in range(num_epochs):
        total_train_loss_nll = 0.0
        total_train_loss_recons = 0.0

        for data in train_loader:
            inputs = data

            # Zero the gradients
            optimizer.zero_grad()

            # NLL Loss calculation
            encoded = model.encode(inputs)
            train_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))
            
            # Reconstruction loss calculation
            _, decoded = model.decode(encoded)
            train_loss_recons = mse_loss(inputs, decoded)
            
            # Backward pass (gradient computation)
            loss = train_loss_nll + train_loss_recons
            loss.backward()

            ### added recently: clip the gradients
            clip_grad_norm_(model.parameters(), max_norm=1.0)  # Adjust max_norm as needed

            # Update weights
            optimizer.step()

            total_train_loss_nll += train_loss_nll.item()
            total_train_loss_recons += train_loss_recons.item()

        # Average training loss for the epoch
        average_train_loss_nll = total_train_loss_nll / len(train_loader)
        average_train_loss_recons = total_train_loss_recons / len(train_loader)

        # Validation phase
        model.eval()  # Set the model to evaluation mode
        if val_loader is not None:
            total_val_loss_nll = 0.0
            total_val_loss_recons = 0.0
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs = val_data

                    # NLL Loss calculation
                    encoded = model.encode(val_inputs)
                    val_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(val_loader))
                    
                    # Reconstruction loss calculation
                    _, decoded = model.decode(encoded)
                    val_loss_recons = mse_loss(val_inputs, decoded)

                    total_val_loss_nll += val_loss_nll.item()
                    total_val_loss_recons += val_loss_recons.item()

            # Average validation loss for the epoch
            average_val_loss_nll = total_val_loss_nll / len(val_loader)
            average_val_loss_recons = total_val_loss_recons / len(val_loader)

            # Print training and validation losses together
            if (epoch + 1) % print_after == 0:
                print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss_recons+average_train_loss_nll}, Validation Loss: {average_val_loss_nll+average_val_loss_recons}")

            # Append losses to the lists
            train_losses_nll.append(average_train_loss_nll)
            val_losses_nll.append(average_val_loss_nll)
            train_losses_recons.append(average_train_loss_recons)
            val_losses_recons.append(average_val_loss_recons)


        # Set the model back to training mode
        model.train()

    print("Training complete")

    return train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons
In [ ]:
k_values = [2,4,8]
dataset_percentage = 1.0
# Create data loader for the fixed dataset size
data_considered = train_datasets[dataset_percentage]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset['X'], batch_size=32, shuffle=True)

for k in k_values:
    print(f"\nTraining for k={k}")

    # Instantiate the model
    model = RealNVP_bottleneck(input_size=64, hidden_size=100, blocks=10,k=k)

    # Train the model
    train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_evaluate(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1)
    train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
    val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)
    # plotting the loss
    plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
    plt.show()

    # Example usage:
    plot_code_distribution(model=model, test_loader=val_loader)
    plt.show()

    ### plot the synthetic data
    synthetic_data=model.sample(num_samples=len(data_considered))
    visualize_synthetic_data(synthetic_data)
    
Training for k=2
Epoch 1/20, Training Loss: 1134181.2054025438, Validation Loss: 393.4980107943217
Epoch 2/20, Training Loss: 104.72341115739611, Validation Loss: 277.5674916903178
Epoch 3/20, Training Loss: 85.00006866455078, Validation Loss: 252.0516587893168
Epoch 4/20, Training Loss: 77.07950909932454, Validation Loss: 236.80850172042847
Epoch 5/20, Training Loss: 72.53351398044163, Validation Loss: 228.5046550432841
Epoch 6/20, Training Loss: 69.57301697201198, Validation Loss: 222.44894289970398
Epoch 7/20, Training Loss: 67.61917934417724, Validation Loss: 218.56156961123148
Epoch 8/20, Training Loss: 65.83680402967664, Validation Loss: 215.66464479764304
Epoch 9/20, Training Loss: 64.46760686238606, Validation Loss: 212.84963099161783
Epoch 10/20, Training Loss: 63.096194140116374, Validation Loss: 213.09209056695303
Epoch 11/20, Training Loss: 62.07158679962158, Validation Loss: 207.5499466260274
Epoch 12/20, Training Loss: 61.22519813113742, Validation Loss: 212.94606955846152
Epoch 13/20, Training Loss: 60.01767086452908, Validation Loss: 206.5381192366282
Epoch 14/20, Training Loss: 59.57084166208903, Validation Loss: 203.71867847442627
Epoch 15/20, Training Loss: 58.656814617580835, Validation Loss: 205.08861184120178
Epoch 16/20, Training Loss: 58.426832644144696, Validation Loss: 205.47613739967346
Epoch 17/20, Training Loss: 57.60618413289388, Validation Loss: 204.82380390167236
Epoch 18/20, Training Loss: 56.88742605845133, Validation Loss: 205.32237219810486
Epoch 19/20, Training Loss: 56.780501280890576, Validation Loss: 206.95415838559467
Epoch 20/20, Training Loss: 56.04133735232883, Validation Loss: 202.103222211202
Training complete
Training for k=4
Epoch 1/20, Training Loss: 1764351.6226888022, Validation Loss: 6121.43034807841
Epoch 2/20, Training Loss: 305.77129427591956, Validation Loss: 328.2409567832947
Epoch 3/20, Training Loss: 93.7455064561632, Validation Loss: 274.5516738096873
Epoch 4/20, Training Loss: 80.29229000939263, Validation Loss: 248.36701107025146
Epoch 5/20, Training Loss: 74.14117916954888, Validation Loss: 236.20576540629068
Epoch 6/20, Training Loss: 69.56096011267768, Validation Loss: 228.18267114957172
Epoch 7/20, Training Loss: 66.54357260598077, Validation Loss: 219.3884553114573
Epoch 8/20, Training Loss: 64.57242608600193, Validation Loss: 215.73117335637411
Epoch 9/20, Training Loss: 62.44744293424819, Validation Loss: 211.98541418711343
Epoch 10/20, Training Loss: 60.773972034454346, Validation Loss: 208.40398891766867
Epoch 11/20, Training Loss: 59.45615416632758, Validation Loss: 205.69273841381073
Epoch 12/20, Training Loss: 58.051778125762944, Validation Loss: 206.46651673316956
Epoch 13/20, Training Loss: 57.31363100475735, Validation Loss: 202.31709150473276
Epoch 14/20, Training Loss: 56.459687858157686, Validation Loss: 203.77659797668457
Epoch 15/20, Training Loss: 55.96581435733371, Validation Loss: 201.2215979496638
Epoch 16/20, Training Loss: 55.15948048697578, Validation Loss: 199.77739560604095
Epoch 17/20, Training Loss: 55.31225707795885, Validation Loss: 200.20230305194855
Epoch 18/20, Training Loss: 54.43091343773736, Validation Loss: 199.18999723593393
Epoch 19/20, Training Loss: 53.94967920515273, Validation Loss: 193.5152560075124
Epoch 20/20, Training Loss: 53.28515498903063, Validation Loss: 197.675413052241
Training complete
Training for k=8
Epoch 1/20, Training Loss: 3957463.0584517587, Validation Loss: 678.2395188013713
Epoch 2/20, Training Loss: 127.46482734680175, Validation Loss: 297.8124193350474
Epoch 3/20, Training Loss: 85.34479524824354, Validation Loss: 255.7841518719991
Epoch 4/20, Training Loss: 75.56492190890842, Validation Loss: 241.27586038907367
Epoch 5/20, Training Loss: 70.02874931759305, Validation Loss: 227.84788632392883
Epoch 6/20, Training Loss: 66.06599095662435, Validation Loss: 219.99094394842783
Epoch 7/20, Training Loss: 63.42344832950168, Validation Loss: 215.41912790139514
Epoch 8/20, Training Loss: 61.34067319234212, Validation Loss: 213.2213078737259
Epoch 9/20, Training Loss: 59.875758753882515, Validation Loss: 210.1514040629069
Epoch 10/20, Training Loss: 58.376855532328285, Validation Loss: 208.3204576174418
Epoch 11/20, Training Loss: 56.99812969631619, Validation Loss: 199.2074755827586
Epoch 12/20, Training Loss: 55.90390529632568, Validation Loss: 201.8850393295288
Epoch 13/20, Training Loss: 55.46685161590577, Validation Loss: 201.70479098955792
Epoch 14/20, Training Loss: 54.837560865614144, Validation Loss: 200.57280039787292
Epoch 15/20, Training Loss: 54.00313622156779, Validation Loss: 201.01371534665427
Epoch 16/20, Training Loss: 53.054822762807206, Validation Loss: 200.3045927286148
Epoch 17/20, Training Loss: 52.68340014351739, Validation Loss: 200.4621553023656
Epoch 18/20, Training Loss: 52.258125665452745, Validation Loss: 198.67079102993011
Epoch 19/20, Training Loss: 51.502514394124354, Validation Loss: 199.64668464660645
Epoch 20/20, Training Loss: 51.13463984595405, Validation Loss: 200.92397185166678
Training complete

We see that the reconstruction already performs far better with the bottleneck, and we get the best results with k = 4. So let's try two different sampling techniques.

In [ ]:
input_size = 64
hidden_size = 100
blocks = 10
print_after=1
dataset_percentage = 1.0
batch_size=32
data_considered = train_datasets[dataset_percentage]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset['X'], batch_size=32, shuffle=True)
# Instantiate the model
model = RealNVP_bottleneck(input_size=input_size, hidden_size=hidden_size, blocks=blocks,k=4)

# Train the model
train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_evaluate(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1)
train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)

### plot the synthetic data
synthetic_data=model.sample_only_important(num_samples=len(data_considered))
visualize_synthetic_data(synthetic_data, title="Sampling only important features")

synthetic_data=model.sample_only_unimportant(num_samples=len(data_considered))
visualize_synthetic_data(synthetic_data, title="Sampling only unimportant features")
    
Epoch 1/20, Training Loss: 506830.6718499078, Validation Loss: 428.4326847394308
Epoch 2/20, Training Loss: 107.78536251915827, Validation Loss: 291.4867718219757
Epoch 3/20, Training Loss: 85.82897991604275, Validation Loss: 259.36583797136944
Epoch 4/20, Training Loss: 77.02423313988581, Validation Loss: 241.9788273970286
Epoch 5/20, Training Loss: 71.95412707858615, Validation Loss: 232.03203002611795
Epoch 6/20, Training Loss: 68.29618448681302, Validation Loss: 227.44942140579224
Epoch 7/20, Training Loss: 65.99424372778998, Validation Loss: 224.72661836942038
Epoch 8/20, Training Loss: 63.48433694839478, Validation Loss: 218.69628206888834
Epoch 9/20, Training Loss: 61.90843456056383, Validation Loss: 216.30499251683554
Epoch 10/20, Training Loss: 60.55102155473497, Validation Loss: 214.02308400472006
Epoch 11/20, Training Loss: 59.381386015150284, Validation Loss: 207.5036437511444
Epoch 12/20, Training Loss: 58.41290695402357, Validation Loss: 206.32842751344046
Epoch 13/20, Training Loss: 57.291498226589624, Validation Loss: 206.07214486598969
Epoch 14/20, Training Loss: 57.13765395482381, Validation Loss: 205.6906545559565
Epoch 15/20, Training Loss: 56.12033676571316, Validation Loss: 209.88612131277722
Epoch 16/20, Training Loss: 55.83825403849284, Validation Loss: 204.41844717661542
Epoch 17/20, Training Loss: 54.830014557308616, Validation Loss: 198.95498553911847
Epoch 18/20, Training Loss: 54.01373918321398, Validation Loss: 203.80519477526346
Epoch 19/20, Training Loss: 53.89666982226902, Validation Loss: 198.23216180006662
Epoch 20/20, Training Loss: 53.23209324942695, Validation Loss: 199.4986126422882
Training complete

We see the algorithm performs as expected: when we sample the important features, the digits / complete images change, while sampling only the unimportant features produces only small changes in the printed images.

3.3 RealNVP with MNIST¶

In [ ]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image, ImageFilter  # install 'pillow' to get PIL
import matplotlib.pyplot as plt

# define a functor to downsample images
class DownsampleTransform:
    def __init__(self, target_shape, algorithm=Image.Resampling.LANCZOS):
        self.width, self.height = target_shape
        self.algorithm = algorithm

    def __call__(self, img):
        # resize to (width + 2, height + 2), then crop off a one-pixel border
        img = img.resize((self.width + 2, self.height + 2), self.algorithm)
        img = img.crop((1, 1, self.width + 1, self.height + 1))
        return img

# concatenate a few transforms
transform = transforms.Compose([
    DownsampleTransform(target_shape=(8, 8)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor()
])

# download MNIST
mnist_dataset = datasets.MNIST(root='./data', train=True,
                               transform=transform, download=True)
# create a DataLoader that serves minibatches of size 100
data_loader = DataLoader(mnist_dataset, batch_size=100, shuffle=True)

mnist_test_dataset = datasets.MNIST(root='./data', train=False,
                                    transform=transform, download=True)
val_loader = DataLoader(mnist_test_dataset, batch_size=100, shuffle=True)

# visualize the first batch of downsampled MNIST images
def show_first_batch(data_loader):
    for batch in data_loader:
        x, y = batch
        fig = plt.figure(figsize=(10, 10))
        for i, img in enumerate(x):
            ax = fig.add_subplot(10, 10, i + 1)
            ax.imshow(img.reshape(8, 8), cmap='gray')
            ax.axis('off')
        break

show_first_batch(data_loader)
In [ ]:
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

def train_and_evaluate(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1):
    """
    Train the RealNVP model and evaluate on a validation dataset.

    Args:
    - model (RealNVP): The RealNVP model to be trained.
    - train_loader (DataLoader): DataLoader for the training dataset.
    - val_loader (DataLoader): DataLoader for the validation dataset.
    - num_epochs (int): Number of training epochs.
    - lr (float): Learning rate for the optimizer.
    - print_after (int): Number of epochs after which to print the training and validation loss.

    Returns:
    - train_losses_nll (list): Training NLL losses for each epoch.
    - train_losses_recons (list): Training reconstruction losses for each epoch.
    - val_losses_nll (list): Validation NLL losses for each epoch.
    - val_losses_recons (list): Validation reconstruction losses for each epoch.
    """

    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    train_losses_nll = [] 
    val_losses_nll = []    
    train_losses_recons = []  
    val_losses_recons = []    
    # Training phase
    model.train()  # Set the model to training mode
    
    for epoch in range(num_epochs):
        total_train_loss_nll = 0.0
        total_train_loss_recons = 0.0

        for batch in train_loader:
            X, y = batch
            inputs= X.reshape(len(y),64)

            # Zero the gradients
            optimizer.zero_grad()

            # NLL Loss calculation
            encoded = model.encode(inputs)
            train_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))
            
            # Reconstruction loss calculation
            decoded = model.decode(encoded)
            train_loss_recons = mse_loss(inputs, decoded)
            
            # Backward pass (gradient computation)
            loss = train_loss_nll + train_loss_recons
            loss.backward()

            ### added recently: clip the gradients
            clip_grad_norm_(model.parameters(), max_norm=1.0)  # Adjust max_norm as needed

            # Update weights
            optimizer.step()

            total_train_loss_nll += train_loss_nll.item()
            total_train_loss_recons += train_loss_recons.item()

        # Average training loss for the epoch
        average_train_loss_nll = total_train_loss_nll / len(train_loader)
        average_train_loss_recons = total_train_loss_recons / len(train_loader)

        # Validation phase
        model.eval()  # Set the model to evaluation mode
        if val_loader is not None:
            total_val_loss_nll = 0.0
            total_val_loss_recons = 0.0
            with torch.no_grad():
                for batch in val_loader:
                    X,y = batch
                    val_inputs = X.reshape(len(y),64)

                    # NLL Loss calculation
                    encoded = model.encode(val_inputs)
                    val_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(val_loader))
                    
                    # Reconstruction loss calculation
                    decoded = model.decode(encoded)
                    val_loss_recons = mse_loss(val_inputs, decoded)

                    total_val_loss_nll += val_loss_nll.item()
                    total_val_loss_recons += val_loss_recons.item()

            # Average validation loss for the epoch
            average_val_loss_nll = total_val_loss_nll / len(val_loader)
            average_val_loss_recons = total_val_loss_recons / len(val_loader)

            # Print training and validation losses together
            if (epoch + 1) % print_after == 0:
                print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss_recons+average_train_loss_nll}, Validation Loss: {average_val_loss_nll+average_val_loss_recons}")

            # Append losses to the lists
            train_losses_nll.append(average_train_loss_nll)
            val_losses_nll.append(average_val_loss_nll)
            train_losses_recons.append(average_train_loss_recons)
            val_losses_recons.append(average_val_loss_recons)


        # Set the model back to training mode
        model.train()

    print("Training complete")

    return train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons
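The NLL term above is computed by calculate_loss, which is defined earlier in the notebook. For reference, a minimal sketch of a change-of-variables loss with the same signature might look as follows; this is an assumption about its structure, not the notebook's exact implementation (in particular, how num_batches is used is a guess):

In [ ]:
import math
import torch

def calculate_loss_sketch(encoded, scaling_before_exp_list, num_batches):
    # Hypothetical stand-in for calculate_loss; the real implementation may differ.
    d = encoded.shape[1]
    # log N(z; 0, I) = -0.5 * ||z||^2 - (d/2) * log(2*pi), per sample
    log_prior = -0.5 * (encoded ** 2).sum(dim=1) - 0.5 * d * math.log(2 * math.pi)
    # the coupling layers scale by exp(scaling_before_exp), so scaling_before_exp
    # holds the per-dimension log-scales and its sum is each layer's log-det term
    log_det = sum(s.sum(dim=1) for s in scaling_before_exp_list)
    # num_batches is kept for signature compatibility only
    return -(log_prior + log_det).mean()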
In [ ]:
def plot_code_distribution(model, test_loader):
    """
    Plot the code distribution obtained by applying the trained RealNVP model to a test dataset.

    Args:
    - model (RealNVP): Trained RealNVP model.
    - test_loader (DataLoader): DataLoader for the test dataset.

    Returns:
    None (displays the plot).
    """
    model.eval()  # Set the model to evaluation mode
    fig, axs = plt.subplots(2, 5, figsize=(20, 7))
    with torch.no_grad():
        # Concatenate multiple batches to obtain more samples
        test_samples = torch.cat([X for (X,y) in test_loader], dim=0)
        test_samples = test_samples.reshape(len(test_samples), 64)
        # Assuming your model has an `encode` method
        code_samples = model.encode(test_samples)

        # Convert PyTorch tensor to numpy array
        code_np = code_samples.numpy()
        dim_1 = 0
        dim_2 = 1
        for i in range(2):
            for j in range(5):
                # Scatter plot of code distribution
                axs[i,j].scatter(code_np[:, dim_1], code_np[:, dim_2], label='Code Distribution', alpha=0.5)
                axs[i,j].set_xlabel(f"Code Dimension {dim_1}")
                axs[i,j].set_ylabel(f"Code Dimension {dim_2}")
                axs[i,j].set_title(f'Code dims {dim_1} vs {dim_2}')
                dim_1 += 1
                dim_2 += 1
        plt.tight_layout()
        plt.show()
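Besides pairwise scatter plots, it is worth checking that the marginals of the code are close to a standard normal. A minimal sketch under the same encode interface as above (the function name is illustrative):

In [ ]:
import numpy as np
import torch
import matplotlib.pyplot as plt

def plot_code_histogram(model, test_loader, dim=0):
    # compare the marginal of one code dimension against N(0, 1)
    model.eval()
    with torch.no_grad():
        samples = torch.cat([X for (X, y) in test_loader], dim=0)
        codes = model.encode(samples.reshape(len(samples), 64)).numpy()
    grid = np.linspace(-4, 4, 200)
    plt.hist(codes[:, dim], bins=50, density=True, alpha=0.6, label=f'code dim {dim}')
    plt.plot(grid, np.exp(-grid**2 / 2) / np.sqrt(2 * np.pi), label='N(0, 1) density')
    plt.legend()
    plt.show()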
In [ ]:
input_size = 64
hidden_size = 200
blocks = 10
print_after=1
dataset_percentage = 1.0
batch_size=100

# Instantiate the model
model = RealNVP(input_size=input_size, hidden_size=hidden_size, blocks=blocks)

# Train the model
train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_evaluate(model, data_loader, val_loader, num_epochs=10, lr=0.005, print_after=1)
train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)

# plotting the loss
plot_losses(train_losses, val_losses, want_log_scale=0)
plt.show()

# Example usage:
plot_code_distribution(model=model, test_loader=val_loader)
plt.show()

### plot the synthetic data
synthetic_data=model.sample(num_samples=100)
visualize_synthetic_data(synthetic_data, title="Synthetic MNIST samples")


    
Epoch 1/10, Training Loss: -19.166037408212674, Validation Loss: -130.39732536315904
Epoch 2/10, Training Loss: -22.551581888198704, Validation Loss: -138.417531890869
Epoch 3/10, Training Loss: -23.38100530624374, Validation Loss: -142.53491348266584
Epoch 4/10, Training Loss: -23.802392495473065, Validation Loss: -143.01150100707991
Epoch 5/10, Training Loss: -24.102571767171064, Validation Loss: -145.12324249267562
Epoch 6/10, Training Loss: -24.330206867853637, Validation Loss: -146.54541503906233
Epoch 7/10, Training Loss: -24.49192481994612, Validation Loss: -146.0686399841307
Epoch 8/10, Training Loss: -24.635013628005815, Validation Loss: -146.93343902587873
Epoch 9/10, Training Loss: -24.749749383926222, Validation Loss: -148.22393341064435
Epoch 10/10, Training Loss: -24.859632444381543, Validation Loss: -148.3804725646971
Training complete

The results look better than with the digits dataset, but using a bottleneck remains superior in both performance and training time.

4. Higher-dimensional data with conditional INN¶

Let's continue Task 3 with a conditional INN. As a reminder of the conditioning mechanism, a short sketch follows.
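The ConditionalCouplingLayer used below is defined earlier in the notebook. The core idea is that the condition (here a one-hot label) is concatenated to the untouched half before predicting the coupling coefficients; the minimal sketch below illustrates this, with layer sizes and names as assumptions rather than the notebook's exact code:

In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConditionalCouplingSketch(nn.Module):
    def __init__(self, input_size, hidden_size, condition_size):
        super().__init__()
        # the conditioning network sees x_a together with the condition vector
        self.fc1 = nn.Linear(input_size // 2 + condition_size, hidden_size)
        self.fc_t = nn.Linear(hidden_size, input_size // 2)  # translation
        self.fc_s = nn.Linear(hidden_size, input_size // 2)  # log-scale

    def forward(self, x, condition):
        x_a, x_b = x.chunk(2, dim=1)
        h = F.relu(self.fc1(torch.cat([x_a, condition], dim=1)))
        translation = self.fc_t(h)
        scaling_before_exp = torch.tanh(self.fc_s(h))
        y_b = x_b * torch.exp(scaling_before_exp) + translation
        return torch.cat([x_a, y_b], dim=1), scaling_before_exp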

4.1 Building the network and testing hyperparameters¶

In [ ]:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import torch

# Load the digits dataset
digits = load_digits()

X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2)
In [ ]:
#### custom dataset wrapper for the digits data
from torch.utils.data import TensorDataset, DataLoader

# Define a custom dataset
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.labels[index]
        return x, y
In [ ]:
# Define model parameters
input_size = 64
hidden_size = 100
condition_size = 10
blocks = 10
percentage = 1.0
num_epochs = 10
lr = 0.005


# convert the label arrays to long tensors usable as class indices
y_train = torch.as_tensor(y_train).long()
y_test = torch.as_tensor(y_test).long()


# Initialize the model
conditional_inn_model = ConditionalRealNVP(input_size, hidden_size, condition_size, blocks)

train_dataset = CustomDataset(torch.FloatTensor(X_train), y_train)
val_dataset = CustomDataset(torch.FloatTensor(X_test), y_test)

# Define batch size
batch_size = 32

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Task 1: Train the Conditional INN
train_losses, val_losses = train_and_validate_conditional_nvp(conditional_inn_model, train_loader, val_loader,
                                                              num_epochs=num_epochs, lr=lr, print_after=1)

# plotting the loss
plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
plt.show()

conditions_all_labels = torch.eye(condition_size)
synthetic_data=conditional_inn_model.sample(num_samples=10, conditions= conditions_all_labels)
visualize_synthetic_data(synthetic_data,title="Synthetic digits from 0 to 9")
plt.show()
Epoch 1/10, Training Loss: 14079.221368747287, Validation Loss: 251.96936988830566
Epoch 2/10, Training Loss: 66.6776138305664, Validation Loss: 224.05490080515543
Epoch 3/10, Training Loss: 59.361120012071396, Validation Loss: 205.93127663930258
Epoch 4/10, Training Loss: 55.38476036919488, Validation Loss: 202.0959882736206
Epoch 5/10, Training Loss: 52.79541244506836, Validation Loss: 191.49610010782877
Epoch 6/10, Training Loss: 50.71405766805013, Validation Loss: 193.31329123179117
Epoch 7/10, Training Loss: 48.682631174723305, Validation Loss: 184.10962549845377
Epoch 8/10, Training Loss: 47.73568098280165, Validation Loss: 183.49630864461264
Epoch 9/10, Training Loss: 46.577716488308376, Validation Loss: 179.7266502380371
Epoch 10/10, Training Loss: 45.92713555230035, Validation Loss: 179.190260887146
Training complete

The results look pretty good! Let's try different hyperparameters.

In [ ]:
hidden_sizes = [100, 200, 400]
blocks = [2, 5, 10]
input_size = 64

for hidden_size in hidden_sizes:
    for block in blocks:
        print(f"\nTraining for hidden_size={hidden_size}, blocks = {block}")

        # Instantiate the model
        conditional_inn_model = ConditionalRealNVP(input_size, hidden_size, condition_size, block)

        train_losses, val_losses = train_and_validate_conditional_nvp(conditional_inn_model, train_loader, val_loader,
                                                                      num_epochs=30, lr=lr, print_after=1)

        # plotting the loss
        plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
        plt.show()

        conditions_all_labels = torch.eye(condition_size)
        synthetic_data=conditional_inn_model.sample(num_samples=10, conditions= conditions_all_labels)
        visualize_synthetic_data(synthetic_data,title="Synthetic digits from 0 to 9")
        plt.show()
        
Training for hidden_size=100, blocks = 2
Epoch 1/30, Training Loss: 293.8609095255534, Validation Loss: 522.8373781840006
Epoch 2/30, Training Loss: 137.28045535617406, Validation Loss: 457.2848320007324
Epoch 3/30, Training Loss: 118.8162612915039, Validation Loss: 405.67692375183105
Epoch 4/30, Training Loss: 107.44375440809462, Validation Loss: 390.22615814208984
Epoch 5/30, Training Loss: 100.79131232367621, Validation Loss: 340.93681780497235
Epoch 6/30, Training Loss: 94.5724119398329, Validation Loss: 333.0782648722331
Epoch 7/30, Training Loss: 91.05331556532118, Validation Loss: 323.25561205546063
Epoch 8/30, Training Loss: 89.2670910305447, Validation Loss: 338.40706125895184
Epoch 9/30, Training Loss: 87.65778791639539, Validation Loss: 304.7028503417969
Epoch 10/30, Training Loss: 84.94269849989149, Validation Loss: 308.5491994222005
Epoch 11/30, Training Loss: 83.56937561035156, Validation Loss: 300.5091025034587
Epoch 12/30, Training Loss: 82.23126559787326, Validation Loss: 293.45045534769696
Epoch 13/30, Training Loss: 80.81693674723307, Validation Loss: 300.28919982910156
Epoch 14/30, Training Loss: 80.5521474202474, Validation Loss: 297.10655721028644
Epoch 15/30, Training Loss: 79.80333201090495, Validation Loss: 291.25659370422363
Epoch 16/30, Training Loss: 78.20782521565755, Validation Loss: 306.5996602376302
Epoch 17/30, Training Loss: 77.80365125868056, Validation Loss: 290.4383665720622
Epoch 18/30, Training Loss: 78.24269239637587, Validation Loss: 275.5026200612386
Epoch 19/30, Training Loss: 77.60413360595703, Validation Loss: 281.0929635365804
Epoch 20/30, Training Loss: 76.18750779893664, Validation Loss: 275.5623073577881
Epoch 21/30, Training Loss: 76.36902414957682, Validation Loss: 287.7348397572835
Epoch 22/30, Training Loss: 76.31015421549479, Validation Loss: 277.7603931427002
Epoch 23/30, Training Loss: 74.88749016655817, Validation Loss: 274.15149815877277
Epoch 24/30, Training Loss: 74.62725897894965, Validation Loss: 288.30015627543133
Epoch 25/30, Training Loss: 74.62852376302084, Validation Loss: 266.2942295074463
Epoch 26/30, Training Loss: 72.85536702473958, Validation Loss: 286.1646811167399
Epoch 27/30, Training Loss: 74.18861253526475, Validation Loss: 269.11271413167316
Epoch 28/30, Training Loss: 73.26144222683376, Validation Loss: 276.10282198588055
Epoch 29/30, Training Loss: 72.91037936740452, Validation Loss: 266.6187432607015
Epoch 30/30, Training Loss: 72.74726952446832, Validation Loss: 269.72121556599933
Training complete
Training for hidden_size=100, blocks = 5
Epoch 1/30, Training Loss: 245.61961432562933, Validation Loss: 288.1628958384196
Epoch 2/30, Training Loss: 73.73544447157118, Validation Loss: 253.28568267822266
Epoch 3/30, Training Loss: 65.67341079711915, Validation Loss: 233.32082843780518
Epoch 4/30, Training Loss: 61.25560048421224, Validation Loss: 223.8289836247762
Epoch 5/30, Training Loss: 58.051698981391056, Validation Loss: 212.6532974243164
Epoch 6/30, Training Loss: 55.78813883463542, Validation Loss: 214.10021591186523
Epoch 7/30, Training Loss: 53.83916973537869, Validation Loss: 211.61553732554117
Epoch 8/30, Training Loss: 52.40808283487956, Validation Loss: 207.70124022165933
Epoch 9/30, Training Loss: 51.4053960164388, Validation Loss: 201.84393946329752
Epoch 10/30, Training Loss: 50.25626102023654, Validation Loss: 203.0097745259603
Epoch 11/30, Training Loss: 49.2913705613878, Validation Loss: 202.8919070561727
Epoch 12/30, Training Loss: 48.77919074164497, Validation Loss: 199.91488869984946
Epoch 13/30, Training Loss: 48.111915503607854, Validation Loss: 198.21980253855386
Epoch 14/30, Training Loss: 47.437728034125435, Validation Loss: 199.25023110707602
Epoch 15/30, Training Loss: 47.18081834581163, Validation Loss: 198.07999897003174
Epoch 16/30, Training Loss: 46.162787458631726, Validation Loss: 198.0883830388387
Epoch 17/30, Training Loss: 46.0449708726671, Validation Loss: 197.23108418782553
Epoch 18/30, Training Loss: 44.870973375108505, Validation Loss: 196.50633462270102
Epoch 19/30, Training Loss: 44.54693255954319, Validation Loss: 193.11839516957602
Epoch 20/30, Training Loss: 43.91679543389215, Validation Loss: 196.35790157318115
Epoch 21/30, Training Loss: 43.8365002102322, Validation Loss: 191.82647037506104
Epoch 22/30, Training Loss: 43.50753716362847, Validation Loss: 195.8693552017212
Epoch 23/30, Training Loss: 43.52292590671115, Validation Loss: 191.5731871922811
Epoch 24/30, Training Loss: 42.970313771565756, Validation Loss: 193.97451178232828
Epoch 25/30, Training Loss: 42.77379387749566, Validation Loss: 188.9208027521769
Epoch 26/30, Training Loss: 42.4356568230523, Validation Loss: 191.19274870554605
Epoch 27/30, Training Loss: 41.90263214111328, Validation Loss: 194.66618855794272
Epoch 28/30, Training Loss: 42.02589831882053, Validation Loss: 193.22452799479166
Epoch 29/30, Training Loss: 41.702737511528866, Validation Loss: 186.7782309850057
Epoch 30/30, Training Loss: 41.530994245741105, Validation Loss: 199.4575522740682
Training complete
Training for hidden_size=100, blocks = 10
Epoch 1/30, Training Loss: 1349.3783764309353, Validation Loss: 252.0936533610026
Epoch 2/30, Training Loss: 64.5319458855523, Validation Loss: 225.27423095703125
Epoch 3/30, Training Loss: 57.80359793768989, Validation Loss: 210.5755526224772
Epoch 4/30, Training Loss: 53.993755171034074, Validation Loss: 205.24729283650717
Epoch 5/30, Training Loss: 51.40060594346788, Validation Loss: 198.3798786799113
Epoch 6/30, Training Loss: 49.362469312879774, Validation Loss: 191.2987314860026
Epoch 7/30, Training Loss: 47.77900839911567, Validation Loss: 197.155779838562
Epoch 8/30, Training Loss: 46.82213287353515, Validation Loss: 192.00089104970297
Epoch 9/30, Training Loss: 45.66339738633898, Validation Loss: 186.95856475830078
Epoch 10/30, Training Loss: 44.15901014539931, Validation Loss: 186.4817123413086
Epoch 11/30, Training Loss: 43.3345470852322, Validation Loss: 187.1396245956421
Epoch 12/30, Training Loss: 42.67904968261719, Validation Loss: 188.0238265991211
Epoch 13/30, Training Loss: 41.80822982788086, Validation Loss: 185.34127044677734
Epoch 14/30, Training Loss: 41.55522350735134, Validation Loss: 187.3965581258138
Epoch 15/30, Training Loss: 41.04959403143989, Validation Loss: 184.37189610799155
Epoch 16/30, Training Loss: 40.496132405598956, Validation Loss: 184.7194267908732
Epoch 17/30, Training Loss: 39.74103876749675, Validation Loss: 184.01869996388754
Epoch 18/30, Training Loss: 39.6324717203776, Validation Loss: 182.57812881469727
Epoch 19/30, Training Loss: 38.70297181871202, Validation Loss: 183.27468649546304
Epoch 20/30, Training Loss: 38.544365607367624, Validation Loss: 184.38559182484946
Epoch 21/30, Training Loss: 38.10673684014215, Validation Loss: 188.7575225830078
Epoch 22/30, Training Loss: 38.150281185574, Validation Loss: 182.02412446339926
Epoch 23/30, Training Loss: 37.513829718695746, Validation Loss: 184.8826446533203
Epoch 24/30, Training Loss: 36.985147603352864, Validation Loss: 187.850931485494
Epoch 25/30, Training Loss: 37.024013943142364, Validation Loss: 188.29494953155518
Epoch 26/30, Training Loss: 36.79429465399848, Validation Loss: 181.92155679066977
Epoch 27/30, Training Loss: 36.5680906507704, Validation Loss: 188.49392795562744
Epoch 28/30, Training Loss: 36.31601587931315, Validation Loss: 185.95901012420654
Epoch 29/30, Training Loss: 36.24478285047743, Validation Loss: 185.64122422536215
Epoch 30/30, Training Loss: 35.76986846923828, Validation Loss: 186.4094565709432
Training complete
Training for hidden_size=200, blocks = 2
Epoch 1/30, Training Loss: 251.62340698242187, Validation Loss: 493.00861167907715
Epoch 2/30, Training Loss: 133.07212592230903, Validation Loss: 445.5189565022786
Epoch 3/30, Training Loss: 120.69439256456164, Validation Loss: 412.8376407623291
Epoch 4/30, Training Loss: 111.98785451253255, Validation Loss: 395.0897928873698
Epoch 5/30, Training Loss: 106.65084398057726, Validation Loss: 371.44888496398926
Epoch 6/30, Training Loss: 104.55206366644965, Validation Loss: 375.08094724019367
Epoch 7/30, Training Loss: 100.62608981662326, Validation Loss: 354.244104385376
Epoch 8/30, Training Loss: 101.3826168484158, Validation Loss: 361.2765522003174
Epoch 9/30, Training Loss: 97.17907290988498, Validation Loss: 350.2586212158203
Epoch 10/30, Training Loss: 99.39001719156902, Validation Loss: 375.323211034139
Epoch 11/30, Training Loss: 96.24877268473307, Validation Loss: 342.65224011739093
Epoch 12/30, Training Loss: 95.22671322292751, Validation Loss: 371.6747303009033
Epoch 13/30, Training Loss: 94.0648691813151, Validation Loss: 335.6468200683594
Epoch 14/30, Training Loss: 94.47991282145182, Validation Loss: 333.73937034606934
Epoch 15/30, Training Loss: 92.03536478678386, Validation Loss: 328.69567171732587
Epoch 16/30, Training Loss: 92.69359588623047, Validation Loss: 339.59826405843097
Epoch 17/30, Training Loss: 90.38678927951389, Validation Loss: 330.8305486043294
Epoch 18/30, Training Loss: 90.6170661078559, Validation Loss: 326.24981753031415
Epoch 19/30, Training Loss: 90.67704806857638, Validation Loss: 326.8300698598226
Epoch 20/30, Training Loss: 90.08084496392144, Validation Loss: 345.3441562652588
Epoch 21/30, Training Loss: 90.32333001030815, Validation Loss: 332.39401563008624
Epoch 22/30, Training Loss: 89.01673482259115, Validation Loss: 329.86750348409015
Epoch 23/30, Training Loss: 89.7814205593533, Validation Loss: 328.3523635864258
Epoch 24/30, Training Loss: 88.09744957817925, Validation Loss: 319.090092976888
Epoch 25/30, Training Loss: 87.52293056911893, Validation Loss: 344.37699190775555
Epoch 26/30, Training Loss: 87.02120480007595, Validation Loss: 350.18481890360516
Epoch 27/30, Training Loss: 84.56066385904948, Validation Loss: 311.2099526723226
Epoch 28/30, Training Loss: 85.24506157769098, Validation Loss: 312.0635674794515
Epoch 29/30, Training Loss: 84.36019151475695, Validation Loss: 322.24486605326337
Epoch 30/30, Training Loss: 85.31957227918836, Validation Loss: 306.62286885579425
Training complete
Training for hidden_size=200, blocks = 5
Epoch 1/30, Training Loss: 314.2896986219618, Validation Loss: 290.5458056131999
Epoch 2/30, Training Loss: 74.96810743543837, Validation Loss: 255.79893811543783
Epoch 3/30, Training Loss: 66.97058885362414, Validation Loss: 235.95261446634927
Epoch 4/30, Training Loss: 62.063047112358944, Validation Loss: 225.87391726175943
Epoch 5/30, Training Loss: 58.89331987169054, Validation Loss: 219.84676583607992
Epoch 6/30, Training Loss: 56.11679161919488, Validation Loss: 213.3170550664266
Epoch 7/30, Training Loss: 54.45528928968641, Validation Loss: 207.7753407160441
Epoch 8/30, Training Loss: 52.892186652289496, Validation Loss: 206.37459564208984
Epoch 9/30, Training Loss: 51.78697018093533, Validation Loss: 210.35921986897787
Epoch 10/30, Training Loss: 50.40202967325846, Validation Loss: 202.17945830027261
Epoch 11/30, Training Loss: 49.897619289822046, Validation Loss: 200.61742146809897
Epoch 12/30, Training Loss: 49.18618816799588, Validation Loss: 200.7643254597982
Epoch 13/30, Training Loss: 47.86549309624566, Validation Loss: 199.72708129882812
Epoch 14/30, Training Loss: 47.19406238132053, Validation Loss: 198.47883065541586
Epoch 15/30, Training Loss: 46.79597422281901, Validation Loss: 196.3492390314738
Epoch 16/30, Training Loss: 46.761542850070526, Validation Loss: 196.71438121795654
Epoch 17/30, Training Loss: 45.04479276869032, Validation Loss: 193.24608580271402
Epoch 18/30, Training Loss: 45.82250001695421, Validation Loss: 197.4443629582723
Epoch 19/30, Training Loss: 44.56912511189778, Validation Loss: 195.75810464223227
Epoch 20/30, Training Loss: 44.55307981703017, Validation Loss: 196.81069056193033
Epoch 21/30, Training Loss: 43.780117713080514, Validation Loss: 190.78466955820718
Epoch 22/30, Training Loss: 43.49141337076823, Validation Loss: 193.90233008066812
Epoch 23/30, Training Loss: 43.6044068230523, Validation Loss: 200.35870520273843
Epoch 24/30, Training Loss: 43.11416583591037, Validation Loss: 198.29298496246338
Epoch 25/30, Training Loss: 42.296264224582245, Validation Loss: 196.31312561035156
Epoch 26/30, Training Loss: 42.06908925374349, Validation Loss: 198.18200302124023
Epoch 27/30, Training Loss: 42.08123363918728, Validation Loss: 197.08484395345053
Epoch 28/30, Training Loss: 41.53609085083008, Validation Loss: 196.21040725708008
Epoch 29/30, Training Loss: 41.187508307562936, Validation Loss: 194.32446511586508
Epoch 30/30, Training Loss: 41.22317564222548, Validation Loss: 199.1661138534546
Training complete
Training for hidden_size=200, blocks = 10
Epoch 1/30, Training Loss: 8550.333620876736, Validation Loss: 252.25765419006348
Epoch 2/30, Training Loss: 64.73046518961588, Validation Loss: 221.92402013142905
Epoch 3/30, Training Loss: 57.476726362440324, Validation Loss: 208.8302043279012
Epoch 4/30, Training Loss: 53.80047760009766, Validation Loss: 204.94884077707925
Epoch 5/30, Training Loss: 51.54618767632378, Validation Loss: 199.60729948679605
Epoch 6/30, Training Loss: 49.243514166937935, Validation Loss: 195.52471828460693
Epoch 7/30, Training Loss: 47.38635228474935, Validation Loss: 190.7220137914022
Epoch 8/30, Training Loss: 46.45990176730686, Validation Loss: 198.09132544199625
Epoch 9/30, Training Loss: 45.40678176879883, Validation Loss: 190.19877115885416
Epoch 10/30, Training Loss: 44.32002614339193, Validation Loss: 190.01830673217773
Epoch 11/30, Training Loss: 43.25411470201281, Validation Loss: 188.1542704900106
Epoch 12/30, Training Loss: 42.34379747178819, Validation Loss: 189.91027164459229
Epoch 13/30, Training Loss: 41.8846063401964, Validation Loss: 185.82088088989258
Epoch 14/30, Training Loss: 41.15804562038846, Validation Loss: 190.71087487538657
Epoch 15/30, Training Loss: 40.05203467475043, Validation Loss: 187.0631825129191
Epoch 16/30, Training Loss: 39.5969729953342, Validation Loss: 185.98451582590738
Epoch 17/30, Training Loss: 39.40454610188802, Validation Loss: 187.34200191497803
Epoch 18/30, Training Loss: 39.446832275390626, Validation Loss: 184.95723565419516
Epoch 19/30, Training Loss: 38.229959615071614, Validation Loss: 183.16800848642984
Epoch 20/30, Training Loss: 38.03151304456923, Validation Loss: 187.79945786794028
Epoch 21/30, Training Loss: 37.87414449055989, Validation Loss: 190.2610190709432
Epoch 22/30, Training Loss: 37.65578426784939, Validation Loss: 185.9701935450236
Epoch 23/30, Training Loss: 36.962117513020836, Validation Loss: 188.50027306874594
Epoch 24/30, Training Loss: 36.36226069132487, Validation Loss: 189.6880203882853
Epoch 25/30, Training Loss: 36.388100941975914, Validation Loss: 190.7065750757853
Epoch 26/30, Training Loss: 36.236502668592664, Validation Loss: 191.58643309275308
Epoch 27/30, Training Loss: 35.916646321614586, Validation Loss: 191.9307139714559
Epoch 28/30, Training Loss: 35.55673332214356, Validation Loss: 189.29961744944254
Epoch 29/30, Training Loss: 35.55209545559353, Validation Loss: 189.1108185450236
Epoch 30/30, Training Loss: 35.0932319217258, Validation Loss: 184.86688454945883
Training complete
Training for hidden_size=400, blocks = 2
Epoch 1/30, Training Loss: 309.8460605197483, Validation Loss: 641.5680033365885
Epoch 2/30, Training Loss: 167.27688564724392, Validation Loss: 550.9989280700684
Epoch 3/30, Training Loss: 155.16913079155816, Validation Loss: 552.8562787373861
Epoch 4/30, Training Loss: 142.22390001085068, Validation Loss: 491.44150416056317
Epoch 5/30, Training Loss: 137.04273817274304, Validation Loss: 793.9056180318197
Epoch 6/30, Training Loss: 158.23646189371746, Validation Loss: 612.9892018636068
Epoch 7/30, Training Loss: 138.3919923570421, Validation Loss: 442.6516710917155
Epoch 8/30, Training Loss: 133.79639689127603, Validation Loss: 836.4147109985352
Epoch 9/30, Training Loss: 146.32217983669705, Validation Loss: 468.4177182515462
Epoch 10/30, Training Loss: 137.2965772840712, Validation Loss: 465.1561864217122
Epoch 11/30, Training Loss: 142.33085208468967, Validation Loss: 486.0627187093099
Epoch 12/30, Training Loss: 129.43356441921657, Validation Loss: 419.21056874593097
Epoch 13/30, Training Loss: 122.93208906385634, Validation Loss: 428.186274210612
Epoch 14/30, Training Loss: 132.58959350585937, Validation Loss: 463.64319864908856
Epoch 15/30, Training Loss: 113.37272321912977, Validation Loss: 377.0710964202881
Epoch 16/30, Training Loss: 117.27298194037543, Validation Loss: 417.78064982096356
Epoch 17/30, Training Loss: 129.6254869249132, Validation Loss: 422.8923905690511
Epoch 18/30, Training Loss: 112.74502393934462, Validation Loss: 425.263973236084
Epoch 19/30, Training Loss: 114.95833892822266, Validation Loss: 388.4588680267334
Epoch 20/30, Training Loss: 116.57845085991754, Validation Loss: 398.36245282491046
Epoch 21/30, Training Loss: 107.6398440890842, Validation Loss: 395.6511694590251
Epoch 22/30, Training Loss: 114.73315056694878, Validation Loss: 365.79894574483234
Epoch 23/30, Training Loss: 153.03194698757596, Validation Loss: 548.2154261271158
Epoch 24/30, Training Loss: 120.47216135660807, Validation Loss: 377.80611483256024
Epoch 25/30, Training Loss: 121.93333214653863, Validation Loss: 359.2391185760498
Epoch 26/30, Training Loss: 107.88890940348307, Validation Loss: 388.6443614959717
Epoch 27/30, Training Loss: 105.48415205213759, Validation Loss: 394.25697898864746
Epoch 28/30, Training Loss: 103.53916642930773, Validation Loss: 359.95915667215985
Epoch 29/30, Training Loss: 103.76796654595269, Validation Loss: 404.914644241333
Epoch 30/30, Training Loss: 101.15233103434245, Validation Loss: 371.7300599416097
Training complete
Training for hidden_size=400, blocks = 5
Epoch 1/30, Training Loss: 283.9519036187066, Validation Loss: 306.4883219401042
Epoch 2/30, Training Loss: 78.9733864678277, Validation Loss: 267.76504071553546
Epoch 3/30, Training Loss: 71.24371761745877, Validation Loss: 249.29519907633463
Epoch 4/30, Training Loss: 66.88938734266493, Validation Loss: 242.08064715067545
Epoch 5/30, Training Loss: 64.17731535169813, Validation Loss: 234.6187343597412
Epoch 6/30, Training Loss: 61.6169075012207, Validation Loss: 232.41984430948892
Epoch 7/30, Training Loss: 61.058763631184895, Validation Loss: 227.2152500152588
Epoch 8/30, Training Loss: 58.36835199991862, Validation Loss: 221.61338710784912
Epoch 9/30, Training Loss: 57.72528330485026, Validation Loss: 223.6118532816569
Epoch 10/30, Training Loss: 56.28718982272678, Validation Loss: 223.97171783447266
Epoch 11/30, Training Loss: 56.164582061767575, Validation Loss: 217.72104358673096
Epoch 12/30, Training Loss: 55.18229760064019, Validation Loss: 218.48025862375894
Epoch 13/30, Training Loss: 54.764756774902345, Validation Loss: 226.10987981160483
Epoch 14/30, Training Loss: 54.429725138346356, Validation Loss: 217.99609343210855
Epoch 15/30, Training Loss: 52.93662999471029, Validation Loss: 219.3150078455607
Epoch 16/30, Training Loss: 52.57396011352539, Validation Loss: 219.934596379598
Epoch 17/30, Training Loss: 52.49516406589084, Validation Loss: 221.11824067433676
Epoch 18/30, Training Loss: 51.70450863308377, Validation Loss: 213.79633712768555
Epoch 19/30, Training Loss: 51.20831722683377, Validation Loss: 216.88543923695883
Epoch 20/30, Training Loss: 51.07510291205512, Validation Loss: 217.75438944498697
Epoch 21/30, Training Loss: 50.03879080878364, Validation Loss: 216.98380406697592
Epoch 22/30, Training Loss: 49.473376210530596, Validation Loss: 220.66783332824707
Epoch 23/30, Training Loss: 49.485553656684026, Validation Loss: 218.93354606628418
Epoch 24/30, Training Loss: 49.50805435180664, Validation Loss: 215.00685501098633
Epoch 25/30, Training Loss: 48.39864205254449, Validation Loss: 218.3308277130127
Epoch 26/30, Training Loss: 48.350467936197916, Validation Loss: 217.81539980570474
Epoch 27/30, Training Loss: 48.22176903618707, Validation Loss: 216.50745010375977
Epoch 28/30, Training Loss: 47.75731006198459, Validation Loss: 213.8287207285563
Epoch 29/30, Training Loss: 47.470498402913414, Validation Loss: 216.02937698364258
Epoch 30/30, Training Loss: 46.7009398566352, Validation Loss: 221.29212951660156
Training complete
Training for hidden_size=400, blocks = 10
Epoch 1/30, Training Loss: 1178.0424274020725, Validation Loss: 314.2991199493408
Epoch 2/30, Training Loss: 80.48548092312284, Validation Loss: 272.5304183959961
Epoch 3/30, Training Loss: 71.183130730523, Validation Loss: 252.77461687723795
Epoch 4/30, Training Loss: 67.16154123942057, Validation Loss: 247.56891632080078
Epoch 5/30, Training Loss: 63.57256266276042, Validation Loss: 238.0752493540446
Epoch 6/30, Training Loss: 61.470615810818146, Validation Loss: 232.28548304239908
Epoch 7/30, Training Loss: 59.5168451944987, Validation Loss: 230.8792053858439
Epoch 8/30, Training Loss: 58.22029656304254, Validation Loss: 229.1913324991862
Epoch 9/30, Training Loss: 57.4977183871799, Validation Loss: 223.06117502848306
Epoch 10/30, Training Loss: 56.045141516791446, Validation Loss: 226.03118069966635
Epoch 11/30, Training Loss: 55.24988725450304, Validation Loss: 221.1461903254191
Epoch 12/30, Training Loss: 55.002064853244356, Validation Loss: 221.07627709706625
Epoch 13/30, Training Loss: 53.88930994669597, Validation Loss: 212.35111586252847
Epoch 14/30, Training Loss: 52.65658925374349, Validation Loss: 217.0702174504598
Epoch 15/30, Training Loss: 52.393392605251734, Validation Loss: 215.39574591318765
Epoch 16/30, Training Loss: 51.5170295715332, Validation Loss: 216.01587708791098
Epoch 17/30, Training Loss: 51.10721206665039, Validation Loss: 216.7215223312378
Epoch 18/30, Training Loss: 50.49084082709418, Validation Loss: 214.5624647140503
Epoch 19/30, Training Loss: 50.000826856825086, Validation Loss: 219.19682820638022
Epoch 20/30, Training Loss: 49.75510025024414, Validation Loss: 215.7393010457357
Epoch 21/30, Training Loss: 49.33414815266927, Validation Loss: 217.59172598520914
Epoch 22/30, Training Loss: 49.12450781928168, Validation Loss: 212.6950419743856
Epoch 23/30, Training Loss: 48.87369783189562, Validation Loss: 218.82337061564127
Epoch 24/30, Training Loss: 48.36188820732964, Validation Loss: 214.2415647506714
Epoch 25/30, Training Loss: 48.280346086290145, Validation Loss: 215.3443225224813
Epoch 26/30, Training Loss: 48.694633399115666, Validation Loss: 217.40858713785806
Epoch 27/30, Training Loss: 48.18829659356011, Validation Loss: 221.2189852396647
Epoch 28/30, Training Loss: 47.713622368706595, Validation Loss: 213.31340758005777
Epoch 29/30, Training Loss: 47.00494562784831, Validation Loss: 217.6350638071696
Epoch 30/30, Training Loss: 47.01122521294488, Validation Loss: 212.82092920939127
Training complete

We still get the best results for the previous network structure (hidden_size = 100 with 10 blocks), so we stick with that architecture and the same training hyperparameters.

4.2 Artificial bottlenecks¶

In [ ]:
### conditional real NVP class
class ConditionalRealNVP_bottleneck(nn.Module):
    def __init__(self, input_size, hidden_size, condition_size, blocks, k):
        """
        Initialize a ConditionalRealNVP model.

        Args:
        - input_size (int): Total size of the input data.
        - hidden_size (int): Size of the hidden layers in the neural networks.
        - condition_size (int): Size of the condition vector (e.g., one-hot encoded label size).
        - blocks (int): Number of coupling layers in the model.
        - k (int): Number of retained ("important") latent dimensions of the artificial bottleneck.
        """
        super(ConditionalRealNVP_bottleneck, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.condition_size = condition_size
        self.blocks = blocks
        self.k = k



        # List of coupling layers
        self.coupling_layers = nn.ModuleList([
            ConditionalCouplingLayer(input_size, hidden_size, condition_size) for _ in range(blocks)
        ])

        # List to store orthonormal matrices
        self.orthonormal_matrices = [self._get_orthonormal_matrix(input_size) for _ in range(blocks)]

        # List to store scaling_before_exp for each block
        self.scaling_before_exp_list = []

    def _get_orthonormal_matrix(self, size):
        """
        Generate a random orthonormal matrix.

        Args:
        - size (int): Size of the matrix.

        Returns:
        - q (torch.Tensor): Orthonormal matrix.
        """
        w = torch.randn(size, size)
        q, _ = torch.linalg.qr(w, 'reduced')
        return q

    def forward_realnvp(self, x, condition):
        """
        Forward pass through the ConditionalRealNVP model.

        Args:
        - x (torch.Tensor): Input data.
        - condition (torch.Tensor): Condition vector.

        Returns:
        - x (torch.Tensor): Transformed data.
        """
        scaling_before_exp_list = []
        for i in range(self.blocks):
            #print("x is:"); print(x)
            #print("shape of x is:"); print(x.shape)
            x = torch.matmul(x, self.orthonormal_matrices[i])
            x, scaling_before_exp = self.coupling_layers[i].forward(x, condition)
            scaling_before_exp_list.append(scaling_before_exp)

        self.scaling_before_exp_list = scaling_before_exp_list
        return x


    def decode(self, z, condition):
        """
        Invert the flow for a latent code z.

        Returns both the full reconstruction and a bottlenecked reconstruction
        in which the latent dimensions beyond k are zeroed out before decoding.
        """
        # Modify z to zero out dimensions beyond k for the reconstruction
        z_reconstructed = z.clone()
        if self.k < self.input_size:
            z_reconstructed[:, self.k:] = 0  # Zero out dimensions beyond k

         # Proceed with the original decoding process
        for i in reversed(range(self.blocks)):
            z = self.coupling_layers[i].backward(z, condition)
            z_reconstructed = self.coupling_layers[i].backward(z_reconstructed, condition)
            z = torch.matmul(z, self.orthonormal_matrices[i].t())
            z_reconstructed = torch.matmul(z_reconstructed, self.orthonormal_matrices[i].t())

        return z, z_reconstructed
    
    def sample(self, num_samples=1000, conditions=None):
        """
        Generate synthetic samples.

        Args:
        - num_samples (int): Number of synthetic samples to generate.
        - conditions (torch.Tensor): Conditions for generating synthetic samples.

        Returns:
        - synthetic_samples (torch.Tensor): Synthetic samples.
        """
        with torch.no_grad():
            z = torch.randn(num_samples, self.input_size)
            synthetic_samples, _ = self.decode(z, conditions)
        return synthetic_samples

    
    def sample_only_important(self, num_samples=1000, conditions=None):
        # Randomize only the first k ("important") latent dimensions; zero the rest
        with torch.no_grad():
            z_1 = torch.randn(num_samples, self.k)
            z_2 = torch.zeros(num_samples, self.input_size - self.k)
            z = torch.cat((z_1, z_2), dim=1)

            # Apply the reverse transformations (decoder) to generate synthetic samples
            synthetic_samples, _ = self.decode(z, conditions)
        return synthetic_samples

    def sample_only_unimportant(self, num_samples=1000, conditions=None):
        # Freeze one draw of the first k dimensions and randomize only the rest
        with torch.no_grad():
            z_1 = torch.randn(1, self.k).repeat(num_samples, 1)
            z_2 = torch.randn(num_samples, self.input_size - self.k)
            z = torch.cat((z_1, z_2), dim=1)

            # Apply the reverse transformations (decoder) to generate synthetic samples
            synthetic_samples, _ = self.decode(z, conditions)
        return synthetic_samples
In [ ]:
### training_the_conditional_nvp model

import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.nn.functional import one_hot

def train_and_validate_conditional_nvp_bottleneck(model, train_loader, val_loader, num_epochs=10, lr=0.001, print_after=1):
    """
    Train the ConditionalRealNVP model and evaluate on a validation dataset.

    Args:
    - model (ConditionalRealNVP): The ConditionalRealNVP model to be trained.
    - train_loader (DataLoader): DataLoader for the training dataset.
    - val_loader (DataLoader): DataLoader for the validation dataset.
    - num_epochs (int): Number of training epochs.
    - lr (float): Learning rate for the optimizer.
    - print_after (int): Number of epochs after which to print the training and validation loss.

    Returns:
    - train_losses (list): List of training losses for each epoch.
    - val_losses (list): List of validation losses for each epoch.
    """

    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    train_losses_nll = [] 
    val_losses_nll = []    
    train_losses_recons = []  
    val_losses_recons = []    
    # Training phase
    model.train()  # Set the model to training mode
    
    for epoch in range(num_epochs):
        total_train_loss_nll = 0.0
        total_train_loss_recons = 0.0

        for data, labels in train_loader:
            inputs = data
            conditions = one_hot(labels, num_classes=model.condition_size).float()

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass (encoding)
            encoded = model.forward_realnvp(inputs, conditions)
            
            train_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))
            
            # Reconstruction loss calculation
            _, decoded = model.decode(encoded, conditions)
            train_loss_recons = mse_loss(inputs, decoded)
            
            # Backward pass (gradient computation)
            loss = train_loss_nll + train_loss_recons
            loss.backward()

            ### added recently: clip the gradients
            clip_grad_norm_(model.parameters(), max_norm=1.0)  # Adjust max_norm as needed

            # Update weights
            optimizer.step()

            total_train_loss_nll += train_loss_nll.item()
            total_train_loss_recons += train_loss_recons.item()

        # Average training loss for the epoch
        average_train_loss_nll = total_train_loss_nll / len(train_loader)
        average_train_loss_recons = total_train_loss_recons / len(train_loader)

        # Validation phase
        model.eval()  # Set the model to evaluation mode
        if val_loader is not None:
            total_val_loss_nll = 0.0
            total_val_loss_recons = 0.0
            with torch.no_grad():
                for val_data, val_labels in val_loader:
                    val_inputs = val_data
                    val_conditions = one_hot(val_labels, num_classes=model.condition_size).float()

                    # Forward pass (encoding) for validation
                    val_encoded = model.forward_realnvp(val_inputs, val_conditions)

                    # NLL Loss calculation
                    val_loss_nll = calculate_loss(val_encoded, model.scaling_before_exp_list, len(val_loader))
                    
                    # Reconstruction loss calculation
                    _, decoded = model.decode(val_encoded, val_conditions)
                    val_loss_recons = mse_loss(val_inputs, decoded)

                    total_val_loss_nll += val_loss_nll.item()
                    total_val_loss_recons += val_loss_recons.item()

            # Average validation loss for the epoch
            average_val_loss_nll = total_val_loss_nll / len(val_loader)
            average_val_loss_recons = total_val_loss_recons / len(val_loader)

            # Print training and validation losses together
            if (epoch + 1) % print_after == 0:
                print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss_recons+average_train_loss_nll}, Validation Loss: {average_val_loss_nll+average_val_loss_recons}")

            # Append losses to the lists
            train_losses_nll.append(average_train_loss_nll)
            val_losses_nll.append(average_val_loss_nll)
            train_losses_recons.append(average_train_loss_recons)
            val_losses_recons.append(average_val_loss_recons)

        # Set the model back to training mode for the next epoch
        model.train()

    print("Training complete")

    return train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons
In [ ]:
k_values = [2,4,8]
dataset_percentage = 1.0

for k in k_values:
    print(f"\nTraining for k={k}")

    # Instantiate the model
    model = ConditionalRealNVP_bottleneck(input_size=64, hidden_size=100, blocks=10,condition_size=10,k=k)

    # Train the model
    train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_validate_conditional_nvp_bottleneck(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1)
    train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
    val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)
    # plotting the loss
    plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
    plt.show()

    ### plot the synthetic data
    conditions_all_labels = torch.eye(condition_size)
    synthetic_data=model.sample_only_important(num_samples=10,conditions= conditions_all_labels)
    visualize_synthetic_data(synthetic_data, title="Sampling only important features")

    conditions_first_elements = torch.zeros((10, 10))
    conditions_first_elements[:,0] = 1
    synthetic_data=model.sample_only_unimportant(num_samples=10, conditions=conditions_first_elements)
    visualize_synthetic_data(synthetic_data, title="Sampling only unimportant features")
    
Training for k=2
Epoch 1/20, Training Loss: 4268.482129012214, Validation Loss: 277.0808455944061
Epoch 2/20, Training Loss: 79.93924753401015, Validation Loss: 238.11668674151102
Epoch 3/20, Training Loss: 70.49177271525065, Validation Loss: 220.5944479306539
Epoch 4/20, Training Loss: 65.97008811102974, Validation Loss: 212.06634493668875
Epoch 5/20, Training Loss: 62.57478442721897, Validation Loss: 206.0178082784017
Epoch 6/20, Training Loss: 60.040938536326095, Validation Loss: 206.48840761184692
Epoch 7/20, Training Loss: 58.67695316738553, Validation Loss: 199.03116782506308
Epoch 8/20, Training Loss: 56.43557636472914, Validation Loss: 194.44701397418976
Epoch 9/20, Training Loss: 55.055921967824304, Validation Loss: 194.8996948401133
Epoch 10/20, Training Loss: 53.81541895336575, Validation Loss: 192.07523147265118
Epoch 11/20, Training Loss: 52.73026489681668, Validation Loss: 193.12475633621216
Epoch 12/20, Training Loss: 52.082277584075925, Validation Loss: 188.29537717501324
Epoch 13/20, Training Loss: 51.161164898342555, Validation Loss: 191.1772176027298
Epoch 14/20, Training Loss: 50.40476257536147, Validation Loss: 187.4039184252421
Epoch 15/20, Training Loss: 49.79507061640422, Validation Loss: 189.2285165389379
Epoch 16/20, Training Loss: 48.9830634329054, Validation Loss: 187.52914067109427
Epoch 17/20, Training Loss: 48.31131310992771, Validation Loss: 189.55784090360004
Epoch 18/20, Training Loss: 47.838651455773245, Validation Loss: 183.39785277843475
Epoch 19/20, Training Loss: 47.57893569734362, Validation Loss: 185.94831204414368
Epoch 20/20, Training Loss: 46.873543463812936, Validation Loss: 185.71018425623575
Training complete
Training for k=4
Epoch 1/20, Training Loss: 1531363.1183485244, Validation Loss: 1611.2183237075806
Epoch 2/20, Training Loss: 145.63676013946534, Validation Loss: 285.2107837994894
Epoch 3/20, Training Loss: 81.61559757656522, Validation Loss: 245.5362807114919
Epoch 4/20, Training Loss: 71.20380795796711, Validation Loss: 225.57799275716147
Epoch 5/20, Training Loss: 65.88819755978054, Validation Loss: 213.04468441009521
Epoch 6/20, Training Loss: 61.858477687835695, Validation Loss: 206.82857819398242
Epoch 7/20, Training Loss: 59.467446655697294, Validation Loss: 203.04642629623413
Epoch 8/20, Training Loss: 56.79460786183675, Validation Loss: 198.47188929716745
Epoch 9/20, Training Loss: 55.198437235090466, Validation Loss: 194.5198189020157
Epoch 10/20, Training Loss: 53.73772185643514, Validation Loss: 194.05041245619455
Epoch 11/20, Training Loss: 52.30835009680854, Validation Loss: 189.8688794374466
Epoch 12/20, Training Loss: 51.18810695012411, Validation Loss: 188.23876953125
Epoch 13/20, Training Loss: 50.2093313852946, Validation Loss: 185.54253792762756
Epoch 14/20, Training Loss: 49.52944510777791, Validation Loss: 186.65644093354544
Epoch 15/20, Training Loss: 48.634717538621686, Validation Loss: 184.9084700345993
Epoch 16/20, Training Loss: 48.00103391011556, Validation Loss: 186.0170479218165
Epoch 17/20, Training Loss: 47.37019804848565, Validation Loss: 185.29169126351675
Epoch 18/20, Training Loss: 46.76975990931193, Validation Loss: 182.8655904928843
Epoch 19/20, Training Loss: 45.86354028913709, Validation Loss: 186.44027853012085
Epoch 20/20, Training Loss: 46.205892425113255, Validation Loss: 184.22301808993024
Training complete
Training for k=8
Epoch 1/20, Training Loss: 12503657.154557291, Validation Loss: 57410.58361816406
Epoch 2/20, Training Loss: 2766.1257029215494, Validation Loss: 341.8003800710042
Epoch 3/20, Training Loss: 92.67972922854953, Validation Loss: 261.6807294686635
Epoch 4/20, Training Loss: 75.06449168523153, Validation Loss: 230.45265928904217
Epoch 5/20, Training Loss: 67.8765141805013, Validation Loss: 218.10639572143555
Epoch 6/20, Training Loss: 63.224973074595134, Validation Loss: 211.41110841433206
Epoch 7/20, Training Loss: 60.210762935214575, Validation Loss: 208.3572313785553
Epoch 8/20, Training Loss: 58.03645001517402, Validation Loss: 199.64385946591696
Epoch 9/20, Training Loss: 56.108029672834604, Validation Loss: 195.0037250916163
Epoch 10/20, Training Loss: 54.359071456061464, Validation Loss: 193.970672527949
Epoch 11/20, Training Loss: 53.008008819156224, Validation Loss: 192.40503108501434
Epoch 12/20, Training Loss: 51.81032321718004, Validation Loss: 190.76259569327038
Epoch 13/20, Training Loss: 50.47278265423245, Validation Loss: 186.2374519109726
Epoch 14/20, Training Loss: 49.65340761608548, Validation Loss: 189.4589294989904
Epoch 15/20, Training Loss: 49.08715744548374, Validation Loss: 185.34620114167532
Epoch 16/20, Training Loss: 48.26118794547187, Validation Loss: 183.30718386173248
Epoch 17/20, Training Loss: 47.3154987970988, Validation Loss: 183.28789893786114
Epoch 18/20, Training Loss: 46.86230368084377, Validation Loss: 185.00911617279053
Epoch 19/20, Training Loss: 46.01728831926982, Validation Loss: 182.14698894818622
Epoch 20/20, Training Loss: 46.02813222673204, Validation Loss: 183.1218525568644
Training complete
In [ ]:
# Instantiate the model
model = ConditionalRealNVP_bottleneck(input_size=64, hidden_size=100, blocks=10,condition_size=10,k=2)

# Train the model
train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_validate_conditional_nvp_bottleneck(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1)
train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)
# plotting the loss
plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
plt.show()
Epoch 1/20, Training Loss: 21580.730417421128, Validation Loss: 297.31214563051856
Epoch 2/20, Training Loss: 83.47490584055582, Validation Loss: 243.01717726389566
Epoch 3/20, Training Loss: 71.68545961380005, Validation Loss: 224.35330470403036
Epoch 4/20, Training Loss: 65.9989339404636, Validation Loss: 214.9265724023183
Epoch 5/20, Training Loss: 62.00650640063815, Validation Loss: 203.16986227035522
Epoch 6/20, Training Loss: 59.44165139728122, Validation Loss: 200.70353790124258
Epoch 7/20, Training Loss: 57.68972770902845, Validation Loss: 196.80060239632925
Epoch 8/20, Training Loss: 55.89540031221178, Validation Loss: 196.14737010002136
Epoch 9/20, Training Loss: 54.81338379118178, Validation Loss: 193.17052300771078
Epoch 10/20, Training Loss: 53.11203460693359, Validation Loss: 190.24123994509378
Epoch 11/20, Training Loss: 52.03209048377143, Validation Loss: 187.1682772239049
Epoch 12/20, Training Loss: 51.64189916186862, Validation Loss: 189.73837987581888
Epoch 13/20, Training Loss: 50.37107356389364, Validation Loss: 184.49492673079175
Epoch 14/20, Training Loss: 49.83295380274455, Validation Loss: 184.8419489065806
Epoch 15/20, Training Loss: 49.164054001702205, Validation Loss: 184.1409958998362
Epoch 16/20, Training Loss: 49.043845907847086, Validation Loss: 185.45416287581125
Epoch 17/20, Training Loss: 48.46328016916911, Validation Loss: 184.55666601657867
Epoch 18/20, Training Loss: 47.334013080596925, Validation Loss: 182.43773651123047
Epoch 19/20, Training Loss: 47.01087157991197, Validation Loss: 185.98060782750449
Epoch 20/20, Training Loss: 46.53891531626384, Validation Loss: 187.26162362098694
Training complete

Here the results for k = 2 look the best: the generated images clearly resemble the digits we tried to generate. Let's quantify the quality with a Random Forest classifier.

In [ ]:
from sklearn.ensemble import RandomForestClassifier
rf_classifier = RandomForestClassifier(n_estimators=100)
rf_classifier.fit(X_train, y_train)
Out[ ]:
RandomForestClassifier()
In [ ]:
### evaluate the synthetic digits with the classifier trained on real data
from sklearn.metrics import accuracy_score
ground_truth = torch.eye(condition_size).repeat(100, 1)
ground_truth_labels = np.argmax(ground_truth, axis=1)
synthetic_data = model.sample(num_samples=1000, conditions=ground_truth)
print(synthetic_data.size())
predictions = rf_classifier.predict(synthetic_data)
accuracy = accuracy_score(ground_truth_labels, predictions)
print(f'Accuracy = {accuracy}')
torch.Size([1000, 64])
Accuracy = 0.868

We see that the Random Forest classifier, trained on real digits, assigns the intended label to the synthetic samples with a reasonable accuracy of 0.87!
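To see which digits transfer well, we could also break this number down per class. A minimal sketch, reusing `ground_truth_labels` and `predictions` from the cell above (this per-class breakdown is our addition, not part of the original evaluation):

In [ ]:
# hypothetical per-class breakdown of the 0.87 overall accuracy
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(ground_truth_labels, predictions)
per_class_acc = cm.diagonal() / cm.sum(axis=1)  # correct predictions per row / samples per class
for digit, acc in enumerate(per_class_acc):
    print(f"digit {digit}: accuracy {acc:.2f}")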

4.3 Test on downsampled MNIST¶

In [ ]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image, ImageFilter  # install 'pillow' to get PIL
import matplotlib.pyplot as plt

# define a functor to downsample images
class DownsampleTransform:
    def __init__(self, target_shape, algorithm=Image.Resampling.LANCZOS):
        self.width, self.height = target_shape
        self.algorithm = algorithm

    def __call__(self, img):
        # resize with a 1-pixel border, then crop the border away
        img = img.resize((self.width + 2, self.height + 2), self.algorithm)
        img = img.crop((1, 1, self.width + 1, self.height + 1))
        return img

# concatenate a few transforms
transform = transforms.Compose([
    DownsampleTransform(target_shape=(8, 8)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor()
])

# download MNIST
mnist_dataset = datasets.MNIST(root='./data', train=True,
                               transform=transform, download=True)
# create a DataLoader that serves minibatches of size 100
data_loader = DataLoader(mnist_dataset, batch_size=100, shuffle=True)

mnist_test_dataset = datasets.MNIST(root='./data', train=False,
                                    transform=transform, download=True)
val_loader = DataLoader(mnist_test_dataset, batch_size=100, shuffle=True)

# visualize the first batch of downsampled MNIST images
def show_first_batch(data_loader):
    for batch in data_loader:
        x, y = batch
        fig = plt.figure(figsize=(10, 10))
        for i, img in enumerate(x):
            ax = fig.add_subplot(10, 10, i + 1)
            ax.imshow(img.reshape(8, 8), cmap='gray')
            ax.axis('off')
        break


show_first_batch(data_loader)
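Note that the transform resizes to 10×10 and then crops away a 1-pixel border, presumably to avoid edge artifacts from the resampling filter. A quick (hypothetical) sanity check that the composed transform produces the expected 1×8×8 tensors:

In [ ]:
# hypothetical sanity check: each transformed MNIST image should be a 1x8x8 tensor in [0, 1]
img, label = mnist_dataset[0]
print(img.shape)                           # expected: torch.Size([1, 8, 8])
print(img.min().item(), img.max().item())  # ToTensor scales pixels to [0, 1]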
In [ ]:
### training_the_conditional_nvp model

import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

def train_and_validate_conditional_nvp(model, train_loader, val_loader, num_epochs=10, lr=0.001, print_after=1):
    """
    Train the ConditionalRealNVP model and evaluate on a validation dataset.

    Args:
    - model (ConditionalRealNVP): The ConditionalRealNVP model to be trained.
    - train_loader (DataLoader): DataLoader for the training dataset.
    - val_loader (DataLoader): DataLoader for the validation dataset.
    - num_epochs (int): Number of training epochs.
    - lr (float): Learning rate for the optimizer.
    - print_after (int): Number of epochs after which to print the training and validation loss.

    Returns:
    - train_losses (list): List of training losses for each epoch.
    - val_losses (list): List of validation losses for each epoch.
    """

    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()

    train_losses_nll = [] 
    val_losses_nll = []    
    train_losses_recons = []  
    val_losses_recons = []    
    # Training phase
    for epoch in range(num_epochs):
        # Set the model back to training mode each epoch
        # (the validation phase below switches it to eval mode)
        model.train()
        total_train_loss_nll = 0.0
        total_train_loss_recons = 0.0

        for data, labels in train_loader:
            inputs = data.reshape(len(labels),64)
            conditions = one_hot(labels, num_classes=model.condition_size).float()

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass (encoding)
            encoded = model.forward_realnvp(inputs, conditions)
            
            train_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))
            
            # Reconstruction loss calculation
            decoded = model.decode(encoded, conditions)
            train_loss_recons = mse_loss(inputs, decoded)
            
            # Backward pass (gradient computation)
            loss = train_loss_nll + train_loss_recons
            loss.backward()

            ### added recently: clip the gradients
            clip_grad_norm_(model.parameters(), max_norm=1.0)  # Adjust max_norm as needed

            # Update weights
            optimizer.step()

            total_train_loss_nll += train_loss_nll.item()
            total_train_loss_recons += train_loss_recons.item()

        # Average training loss for the epoch
        average_train_loss_nll = total_train_loss_nll / len(train_loader)
        average_train_loss_recons = total_train_loss_recons / len(train_loader)

        # Validation phase
        model.eval()  # Set the model to evaluation mode
        if val_loader is not None:
            total_val_loss_nll = 0.0
            total_val_loss_recons = 0.0
            with torch.no_grad():
                for val_data, val_labels in val_loader:
                    val_inputs = val_data.reshape(len(val_labels), 64)  # fixed: use val_labels, not labels
                    val_conditions = one_hot(val_labels, num_classes=model.condition_size).float()

                    # Forward pass (encoding) for validation
                    val_encoded = model.forward_realnvp(val_inputs, val_conditions)

                    # NLL Loss calculation
                    val_loss_nll = calculate_loss(val_encoded, model.scaling_before_exp_list, len(val_loader))
                    
                    # Reconstruction loss calculation
                    decoded = model.decode(val_encoded, val_conditions)
                    val_loss_recons = mse_loss(val_inputs, decoded)

                    total_val_loss_nll += val_loss_nll.item()
                    total_val_loss_recons += val_loss_recons.item()

            # Average validation loss for the epoch
            average_val_loss_nll = total_val_loss_nll / len(val_loader)
            average_val_loss_recons = total_val_loss_recons / len(val_loader)

            # Print training and validation losses together
            if (epoch + 1) % print_after == 0:
                print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss_recons+average_train_loss_nll}, Validation Loss: {average_val_loss_nll+average_val_loss_recons}")

            # Append losses to the lists
            train_losses_nll.append(average_train_loss_nll)
            val_losses_nll.append(average_val_loss_nll)
            train_losses_recons.append(average_train_loss_recons)
            val_losses_recons.append(average_val_loss_recons)

    print("Training complete")

    return train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons
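The loop calls a one_hot helper to turn integer labels into condition vectors. If that helper is not already defined earlier in the notebook, PyTorch's built-in torch.nn.functional.one_hot would be a drop-in replacement (a sketch, not necessarily the original helper):

In [ ]:
# assumed helper: one-hot encoding via PyTorch's built-in
import torch
from torch.nn.functional import one_hot

labels = torch.tensor([3, 1, 4])
print(one_hot(labels, num_classes=10).float())
# each row is a length-10 vector with a single 1 at the label's index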
In [ ]:
input_size = 64
hidden_size = 100
blocks = 10
print_after = 1
dataset_percentage = 1.0
batch_size = 100

# Instantiate the model
model = ConditionalRealNVP(input_size=input_size, hidden_size=hidden_size, blocks=blocks, condition_size=10)

# Train the model
train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_validate_conditional_nvp(model, data_loader, val_loader, num_epochs=10, lr=0.005, print_after=print_after)
train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)

# plotting the loss
plot_losses(train_losses, val_losses, want_log_scale=0)
plt.show()


### sample and plot synthetic digits, one per class
conditions_all_labels = torch.eye(condition_size)
synthetic_data = model.sample(num_samples=10, conditions=conditions_all_labels)
visualize_synthetic_data(synthetic_data, title="Synthetic digits from 0 to 9")
plt.show()
Epoch 1/10, Training Loss: -19.08519719739742, Validation Loss: -130.75578613281235
Epoch 2/10, Training Loss: -22.443291050592926, Validation Loss: -137.0801293945311
Epoch 3/10, Training Loss: -23.36428718566881, Validation Loss: -141.2957122802733
Epoch 4/10, Training Loss: -23.87661533991482, Validation Loss: -144.08587219238268
Epoch 5/10, Training Loss: -24.201026268005233, Validation Loss: -145.4849877929686
Epoch 6/10, Training Loss: -24.45872751235948, Validation Loss: -147.0356079101561
Epoch 7/10, Training Loss: -24.656676244735575, Validation Loss: -147.69472976684557
Epoch 8/10, Training Loss: -24.829769868850565, Validation Loss: -149.69698562622057
Epoch 9/10, Training Loss: -24.958654168446717, Validation Loss: -150.0165017700194
Epoch 10/10, Training Loss: -25.043524036407327, Validation Loss: -149.860236968994
Training complete

On the full MNIST dataset this yields far worse results than the bottleneck approach, but that might simply be because longer training is needed.
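If longer training is indeed what is missing, the simplest next step would be to keep training the same model instance, for example with a smaller learning rate; a hedged sketch (note that calling the function again creates a fresh Adam optimizer, so its moment estimates are not carried over):

In [ ]:
# hypothetical continuation run: same model, more epochs, lower learning rate
more_nll, more_recons, more_val_nll, more_val_recons = train_and_validate_conditional_nvp(
    model, data_loader, val_loader, num_epochs=20, lr=0.001, print_after=1)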